From dc2635c06f8e4dee994bb94e63c8e4d8a08faa94 Mon Sep 17 00:00:00 2001 From: al8n Date: Tue, 25 Feb 2025 01:10:05 +0800 Subject: [PATCH 01/39] WIP --- .github/workflows/coverage.yml | 2 +- Cargo.toml | 34 +- {core => serf-core}/Cargo.toml | 12 +- {core => serf-core}/src/broadcast.rs | 2 +- {core => serf-core}/src/coalesce.rs | 8 +- {core => serf-core}/src/coalesce/member.rs | 10 +- {core => serf-core}/src/coalesce/user.rs | 6 +- {core => serf-core}/src/coordinate.rs | 2 +- {core => serf-core}/src/delegate.rs | 2 +- {core => serf-core}/src/delegate/composite.rs | 4 +- {core => serf-core}/src/delegate/merge.rs | 2 +- {core => serf-core}/src/delegate/reconnect.rs | 2 +- {core => serf-core}/src/delegate/transform.rs | 12 +- {core => serf-core}/src/error.rs | 0 {core => serf-core}/src/event.rs | 4 +- {core => serf-core}/src/event/crate_event.rs | 2 +- {core => serf-core}/src/key_manager.rs | 8 +- {core => serf-core}/src/lib.rs | 2 +- {core => serf-core}/src/options.rs | 0 {core => serf-core}/src/serf.rs | 6 +- {core => serf-core}/src/serf/api.rs | 2 +- {core => serf-core}/src/serf/base.rs | 15 +- {core => serf-core}/src/serf/base/tests.rs | 6 +- .../src/serf/base/tests/serf.rs | 4 +- .../src/serf/base/tests/serf/delegate.rs | 0 .../src/serf/base/tests/serf/event.rs | 2 +- .../src/serf/base/tests/serf/join.rs | 32 +- .../src/serf/base/tests/serf/leave.rs | 16 +- .../src/serf/base/tests/serf/reap.rs | 0 .../src/serf/base/tests/serf/reconnect.rs | 0 .../src/serf/base/tests/serf/remove.rs | 0 .../src/serf/base/tests/serf/snapshot.rs | 4 +- {core => serf-core}/src/serf/delegate.rs | 8 +- .../src/serf/internal_query.rs | 4 +- {core => serf-core}/src/serf/query.rs | 4 +- {core => serf-core}/src/snapshot.rs | 8 +- {core => serf-core}/src/types.rs | 2 +- {core => serf-core}/src/types/member.rs | 2 +- {core => serf-core}/src/types/message.rs | 0 {types => serf-proto}/Cargo.toml | 21 +- serf-proto/src/arbitrary_impl.rs | 73 ++ {types => serf-proto}/src/clock.rs | 62 +- 
serf-proto/src/filter.rs | 80 +++ serf-proto/src/join.rs | 46 ++ serf-proto/src/key.rs | 153 +++++ serf-proto/src/leave.rs | 33 + {types => serf-proto}/src/lib.rs | 12 +- serf-proto/src/member.rs | 187 ++++++ {types => serf-proto}/src/message.rs | 100 +-- serf-proto/src/push_pull.rs | 147 ++++ serf-proto/src/query.rs | 162 +++++ serf-proto/src/tags.rs | 58 ++ serf-proto/src/user_event.rs | 109 +++ serf-proto/src/version.rs | 149 +++++ serf/Cargo.toml | 4 +- serf/test/main.rs | 4 +- types/src/filter.rs | 298 --------- types/src/join.rs | 188 ------ types/src/key.rs | 617 ----------------- types/src/leave.rs | 178 ----- types/src/member.rs | 402 ----------- types/src/push_pull.rs | 467 ------------- types/src/query.rs | 631 ------------------ types/src/tags.rs | 202 ------ types/src/user_event.rs | 525 --------------- types/src/version.rs | 152 ----- 66 files changed, 1361 insertions(+), 3928 deletions(-) rename {core => serf-core}/Cargo.toml (85%) rename {core => serf-core}/src/broadcast.rs (96%) rename {core => serf-core}/src/coalesce.rs (95%) rename {core => serf-core}/src/coalesce/member.rs (98%) rename {core => serf-core}/src/coalesce/user.rs (98%) rename {core => serf-core}/src/coordinate.rs (99%) rename {core => serf-core}/src/delegate.rs (94%) rename {core => serf-core}/src/delegate/composite.rs (99%) rename {core => serf-core}/src/delegate/merge.rs (96%) rename {core => serf-core}/src/delegate/reconnect.rs (96%) rename {core => serf-core}/src/delegate/transform.rs (97%) rename {core => serf-core}/src/error.rs (100%) rename {core => serf-core}/src/event.rs (99%) rename {core => serf-core}/src/event/crate_event.rs (99%) rename {core => serf-core}/src/key_manager.rs (98%) rename {core => serf-core}/src/lib.rs (97%) rename {core => serf-core}/src/options.rs (100%) rename {core => serf-core}/src/serf.rs (99%) rename {core => serf-core}/src/serf/api.rs (100%) rename {core => serf-core}/src/serf/base.rs (99%) rename {core => serf-core}/src/serf/base/tests.rs (99%) 
rename {core => serf-core}/src/serf/base/tests/serf.rs (99%) rename {core => serf-core}/src/serf/base/tests/serf/delegate.rs (100%) rename {core => serf-core}/src/serf/base/tests/serf/event.rs (99%) rename {core => serf-core}/src/serf/base/tests/serf/join.rs (93%) rename {core => serf-core}/src/serf/base/tests/serf/leave.rs (96%) rename {core => serf-core}/src/serf/base/tests/serf/reap.rs (100%) rename {core => serf-core}/src/serf/base/tests/serf/reconnect.rs (100%) rename {core => serf-core}/src/serf/base/tests/serf/remove.rs (100%) rename {core => serf-core}/src/serf/base/tests/serf/snapshot.rs (99%) rename {core => serf-core}/src/serf/delegate.rs (99%) rename {core => serf-core}/src/serf/internal_query.rs (99%) rename {core => serf-core}/src/serf/query.rs (99%) rename {core => serf-core}/src/snapshot.rs (99%) rename {core => serf-core}/src/types.rs (99%) rename {core => serf-core}/src/types/member.rs (98%) rename {core => serf-core}/src/types/message.rs (100%) rename {types => serf-proto}/Cargo.toml (54%) create mode 100644 serf-proto/src/arbitrary_impl.rs rename {types => serf-proto}/src/clock.rs (80%) create mode 100644 serf-proto/src/filter.rs create mode 100644 serf-proto/src/join.rs create mode 100644 serf-proto/src/key.rs create mode 100644 serf-proto/src/leave.rs rename {types => serf-proto}/src/lib.rs (69%) create mode 100644 serf-proto/src/member.rs rename {types => serf-proto}/src/message.rs (74%) create mode 100644 serf-proto/src/push_pull.rs create mode 100644 serf-proto/src/query.rs create mode 100644 serf-proto/src/tags.rs create mode 100644 serf-proto/src/user_event.rs create mode 100644 serf-proto/src/version.rs delete mode 100644 types/src/filter.rs delete mode 100644 types/src/join.rs delete mode 100644 types/src/key.rs delete mode 100644 types/src/leave.rs delete mode 100644 types/src/member.rs delete mode 100644 types/src/push_pull.rs delete mode 100644 types/src/query.rs delete mode 100644 types/src/tags.rs delete mode 100644 
types/src/user_event.rs delete mode 100644 types/src/version.rs diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 2efdb05..78662dc 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -36,7 +36,7 @@ jobs: name: "serf-core" - crate: types features: "metrics,encryption" - name: "serf-types" + name: "serf-proto" - crate: serf features: "test,tokio,tcp,encryption,metrics" name: "serf-tcp-encryption" diff --git a/Cargo.toml b/Cargo.toml index 814f205..28ba986 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,15 +1,15 @@ [workspace] members = [ - "core", "serf", - "types" + "serf-core", + "serf-proto" ] -resolver = "2" +resolver = "3" [workspace.package] -version = "0.2.0" -edition = "2021" -rust-version = "1.81.0" +version = "0.3.0" +edition = "2024" +rust-version = "1.85.0" description = "A decentralized solution for service discovery and orchestration that is lightweight, highly available, and fault tolerant." repository = "https://github.com/al8n/serf" homepage = "https://github.com/al8n/serf" @@ -20,23 +20,23 @@ readme = "README.md" [workspace.dependencies] byteorder = "1" -derive_more = { version = "1", features = ["full"] } +derive_more = { version = "2", features = ["full"] } futures = { version = "0.3", default-features = false } serde = { version = "1", features = ["derive"] } humantime-serde = "1" indexmap = "2" -memberlist-types = { version = "0.3", default-features = false } -memberlist-core = { version = "0.3", default-features = false } -memberlist = { version = "0.3", default-features = false } -thiserror = "2" +# memberlist-proto = { version = "0.3", default-features = false } +# memberlist-core = { version = "0.3", default-features = false } +# memberlist = { version = "0.3", default-features = false } +thiserror = { version = "2", default-features = false } viewit = "0.1.5" smol_str = "0.3" smallvec = "1" -rand = "0.8" +rand = "0.9" -# memberlist-types = { version = "0.3", path = 
"../memberlist/types", default-features = false } -# memberlist-core = { version = "0.3", path = "../memberlist/core", default-features = false } -# memberlist = { version = "0.3", path = "../memberlist/memberlist", default-features = false } +memberlist-proto = { version = "0.1", path = "../memberlist/memberlist-proto", default-features = false } +memberlist-core = { version = "0.6", path = "../memberlist/memberlist-core", default-features = false } +memberlist = { version = "0.6", path = "../memberlist/memberlist", default-features = false } -serf-core = { path = "core", version = "0.2.0", default-features = false } -serf-types = { path = "types", version = "0.2.0", default-features = false } +serf-core = { path = "serf-core", version = "0.3.0", default-features = false } +serf-proto = { path = "serf-proto", version = "0.1.0", default-features = false } diff --git a/core/Cargo.toml b/serf-core/Cargo.toml similarity index 85% rename from core/Cargo.toml rename to serf-core/Cargo.toml index c37200a..584cca9 100644 --- a/core/Cargo.toml +++ b/serf-core/Cargo.toml @@ -13,15 +13,14 @@ categories.workspace = true [features] default = ["metrics"] -metrics = ["memberlist-core/metrics", "dep:metrics", "serf-types/metrics"] -encryption = ["memberlist-core/encryption", "serf-types/encryption", "base64", "serde"] -async-graphql = ["dep:async-graphql"] +metrics = ["memberlist-core/metrics", "dep:metrics", "serf-proto/metrics"] +encryption = ["memberlist-core/encryption", "serf-proto/encryption", "base64", "serde"] serde = [ "dep:serde", "dep:humantime-serde", "memberlist-core/serde", - "serf-types/serde", + "serf-proto/serde", "smol_str/serde", "smallvec/serde", "indexmap/serde", @@ -35,7 +34,6 @@ atomic_refcell = "0.1" arc-swap = "1" async-lock = "3" async-channel = "2" -async-graphql = { version = "7", optional = true } byteorder.workspace = true crossbeam-queue = "0.3" derive_more.workspace = true @@ -53,7 +51,7 @@ smallvec.workspace = true thiserror.workspace = true 
viewit.workspace = true memberlist-core.workspace = true -serf-types.workspace = true +serf-proto.workspace = true metrics = { version = "0.24", optional = true } @@ -72,7 +70,7 @@ tracing-subscriber = { version = "0.3", optional = true, features = [ tempfile = { version = "3", optional = true } [dev-dependencies] -agnostic-lite = { version = "0.3", features = ["tokio"] } +agnostic-lite = { version = "0.5", features = ["tokio"] } tokio = { version = "1", features = ["full"] } futures = { workspace = true, features = ["executor"] } tempfile = "3" diff --git a/core/src/broadcast.rs b/serf-core/src/broadcast.rs similarity index 96% rename from core/src/broadcast.rs rename to serf-core/src/broadcast.rs index c6a0a76..d9a1728 100644 --- a/core/src/broadcast.rs +++ b/serf-core/src/broadcast.rs @@ -1,5 +1,5 @@ use async_channel::Sender; -use memberlist_core::{bytes::Bytes, Broadcast}; +use memberlist_core::{Broadcast, bytes::Bytes}; #[derive(Debug, PartialEq, Eq, Hash, Clone)] pub(crate) struct BroadcastId; diff --git a/core/src/coalesce.rs b/serf-core/src/coalesce.rs similarity index 95% rename from core/src/coalesce.rs rename to serf-core/src/coalesce.rs index 439d4d8..64660b7 100644 --- a/core/src/coalesce.rs +++ b/serf-core/src/coalesce.rs @@ -5,7 +5,7 @@ pub(crate) use user::*; use std::{future::Future, time::Duration}; -use async_channel::{bounded, Receiver, Sender}; +use async_channel::{Receiver, Sender, bounded}; use futures::FutureExt; use memberlist_core::{ agnostic_lite::RuntimeLite, @@ -21,9 +21,9 @@ pub(crate) struct ClosedOutChannel; pub(crate) trait Coalescer: Send + Sync + 'static { type Delegate: Delegate< - Id = ::Id, - Address = <::Resolver as AddressResolver>::ResolvedAddress, - >; + Id = ::Id, + Address = <::Resolver as AddressResolver>::ResolvedAddress, + >; type Transport: Transport; fn name(&self) -> &'static str; diff --git a/core/src/coalesce/member.rs b/serf-core/src/coalesce/member.rs similarity index 98% rename from core/src/coalesce/member.rs 
rename to serf-core/src/coalesce/member.rs index 12fed24..54b356d 100644 --- a/core/src/coalesce/member.rs +++ b/serf-core/src/coalesce/member.rs @@ -2,9 +2,9 @@ use std::{collections::HashMap, marker::PhantomData}; use async_channel::Sender; use memberlist_core::{ + CheapClone, transport::{AddressResolver, Node, Transport}, types::TinyVec, - CheapClone, }; use crate::{ @@ -125,16 +125,16 @@ mod tests { use futures::FutureExt; use memberlist_core::{ - agnostic_lite::{tokio::TokioRuntime, RuntimeLite}, - transport::{resolver::socket_addr::SocketAddrResolver, tests::UnimplementedTransport, Lpe}, + agnostic_lite::{RuntimeLite, tokio::TokioRuntime}, + transport::{Lpe, resolver::socket_addr::SocketAddrResolver, tests::UnimplementedTransport}, }; - use serf_types::{MemberStatus, UserEventMessage}; + use serf_proto::{MemberStatus, UserEventMessage}; use smol_str::SmolStr; use crate::{ + DefaultDelegate, coalesce::coalesced_event, event::{CrateEventType, MemberEvent}, - DefaultDelegate, }; use super::*; diff --git a/core/src/coalesce/user.rs b/serf-core/src/coalesce/user.rs similarity index 98% rename from core/src/coalesce/user.rs rename to serf-core/src/coalesce/user.rs index 88bb8b3..de82e50 100644 --- a/core/src/coalesce/user.rs +++ b/serf-core/src/coalesce/user.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use indexmap::IndexMap; use memberlist_core::types::TinyVec; -use serf_types::UserEventMessage; +use serf_proto::UserEventMessage; use smol_str::SmolStr; use crate::types::LamportTime; @@ -103,12 +103,12 @@ mod tests { use agnostic_lite::tokio::TokioRuntime; use memberlist_core::transport::{ - resolver::socket_addr::SocketAddrResolver, tests::UnimplementedTransport, Lpe, + Lpe, resolver::socket_addr::SocketAddrResolver, tests::UnimplementedTransport, }; use crate::{ - event::{MemberEvent, MemberEventType}, DefaultDelegate, + event::{MemberEvent, MemberEventType}, }; use super::*; diff --git a/core/src/coordinate.rs b/serf-core/src/coordinate.rs similarity index 
99% rename from core/src/coordinate.rs rename to serf-core/src/coordinate.rs index 0018dcd..096fc9b 100644 --- a/core/src/coordinate.rs +++ b/serf-core/src/coordinate.rs @@ -8,7 +8,7 @@ use byteorder::{ByteOrder, NetworkEndian}; use memberlist_core::CheapClone; use parking_lot::RwLock; use rand::Rng; -use serf_types::Transformable; +use serf_proto::Transformable; use smallvec::SmallVec; /// Used to convert float seconds to nanoseconds. diff --git a/core/src/delegate.rs b/serf-core/src/delegate.rs similarity index 94% rename from core/src/delegate.rs rename to serf-core/src/delegate.rs index 175d928..48288fc 100644 --- a/core/src/delegate.rs +++ b/serf-core/src/delegate.rs @@ -1,4 +1,4 @@ -use memberlist_core::{transport::Id, CheapClone}; +use memberlist_core::{CheapClone, transport::Id}; mod merge; pub use merge::*; diff --git a/core/src/delegate/composite.rs b/serf-core/src/delegate/composite.rs similarity index 99% rename from core/src/delegate/composite.rs rename to serf-core/src/delegate/composite.rs index f57d408..e44e81c 100644 --- a/core/src/delegate/composite.rs +++ b/serf-core/src/delegate/composite.rs @@ -1,9 +1,9 @@ use memberlist_core::{ + CheapClone, transport::{Id, Node}, types::TinyVec, - CheapClone, }; -use serf_types::MessageType; +use serf_proto::MessageType; use crate::{ coordinate::Coordinate, diff --git a/core/src/delegate/merge.rs b/serf-core/src/delegate/merge.rs similarity index 96% rename from core/src/delegate/merge.rs rename to serf-core/src/delegate/merge.rs index e00f1d0..004edde 100644 --- a/core/src/delegate/merge.rs +++ b/serf-core/src/delegate/merge.rs @@ -1,4 +1,4 @@ -use memberlist_core::{transport::Id, types::TinyVec, CheapClone}; +use memberlist_core::{CheapClone, transport::Id, types::TinyVec}; use std::future::Future; use crate::types::Member; diff --git a/core/src/delegate/reconnect.rs b/serf-core/src/delegate/reconnect.rs similarity index 96% rename from core/src/delegate/reconnect.rs rename to 
serf-core/src/delegate/reconnect.rs index 7dc6600..4073052 100644 --- a/core/src/delegate/reconnect.rs +++ b/serf-core/src/delegate/reconnect.rs @@ -1,6 +1,6 @@ use std::time::Duration; -use memberlist_core::{transport::Id, CheapClone}; +use memberlist_core::{CheapClone, transport::Id}; use crate::types::Member; diff --git a/core/src/delegate/transform.rs b/serf-core/src/delegate/transform.rs similarity index 97% rename from core/src/delegate/transform.rs rename to serf-core/src/delegate/transform.rs index f628d53..d357466 100644 --- a/core/src/delegate/transform.rs +++ b/serf-core/src/delegate/transform.rs @@ -1,9 +1,9 @@ use memberlist_core::{ + CheapClone, bytes::Bytes, transport::{Id, Node, Transformable}, - CheapClone, }; -use serf_types::{ +use serf_proto::{ FilterTransformError, JoinMessage, LeaveMessage, Member, MessageType, NodeTransformError, PushPullMessage, QueryMessage, QueryResponseMessage, SerfMessageTransformError, TagsTransformError, UserEventMessage, @@ -252,7 +252,7 @@ where fn message_encoded_len(msg: impl AsMessageRef) -> usize { let msg = msg.as_message_ref(); - serf_types::Encodable::encoded_len(&msg) + serf_proto::Encodable::encoded_len(&msg) } fn encode_message( @@ -260,7 +260,7 @@ where mut dst: impl AsMut<[u8]>, ) -> Result { let msg = msg.as_message_ref(); - serf_types::Encodable::encode(&msg, dst.as_mut()).map_err(Into::into) + serf_proto::Encodable::encode(&msg, dst.as_mut()).map_err(Into::into) } fn decode_message( @@ -291,11 +291,11 @@ where .map_err(|e| Self::Error::Message(e.into())), MessageType::Relay => Err(Self::Error::UnexpectedRelayMessage), #[cfg(feature = "encryption")] - MessageType::KeyRequest => serf_types::KeyRequestMessage::decode(bytes.as_ref()) + MessageType::KeyRequest => serf_proto::KeyRequestMessage::decode(bytes.as_ref()) .map(|(n, m)| (n, SerfMessage::KeyRequest(m))) .map_err(|e| Self::Error::Message(e.into())), #[cfg(feature = "encryption")] - MessageType::KeyResponse => 
serf_types::KeyResponseMessage::decode(bytes.as_ref()) + MessageType::KeyResponse => serf_proto::KeyResponseMessage::decode(bytes.as_ref()) .map(|(n, m)| (n, SerfMessage::KeyResponse(m))) .map_err(|e| Self::Error::Message(e.into())), _ => unreachable!(), diff --git a/core/src/error.rs b/serf-core/src/error.rs similarity index 100% rename from core/src/error.rs rename to serf-core/src/error.rs diff --git a/core/src/event.rs b/serf-core/src/event.rs similarity index 99% rename from core/src/event.rs rename to serf-core/src/event.rs index fcc5017..a0d6384 100644 --- a/core/src/event.rs +++ b/serf-core/src/event.rs @@ -15,12 +15,12 @@ use async_lock::Mutex; pub(crate) use crate_event::*; use futures::Stream; use memberlist_core::{ + CheapClone, bytes::{BufMut, Bytes, BytesMut}, transport::{AddressResolver, Transport}, types::TinyVec, - CheapClone, }; -use serf_types::{ +use serf_proto::{ LamportTime, Member, MessageType, Node, QueryFlag, QueryResponseMessage, UserEventMessage, }; use smol_str::SmolStr; diff --git a/core/src/event/crate_event.rs b/serf-core/src/event/crate_event.rs similarity index 99% rename from core/src/event/crate_event.rs rename to serf-core/src/event/crate_event.rs index 4be8dcc..2b7ba33 100644 --- a/core/src/event/crate_event.rs +++ b/serf-core/src/event/crate_event.rs @@ -1,4 +1,4 @@ -use serf_types::QueryMessage; +use serf_proto::QueryMessage; use super::*; diff --git a/core/src/key_manager.rs b/serf-core/src/key_manager.rs similarity index 98% rename from core/src/key_manager.rs rename to serf-core/src/key_manager.rs index 38d031b..3ee43c4 100644 --- a/core/src/key_manager.rs +++ b/serf-core/src/key_manager.rs @@ -4,25 +4,25 @@ use async_channel::Receiver; use async_lock::RwLock; use futures::StreamExt; use memberlist_core::{ + CheapClone, bytes::{BufMut, BytesMut}, tracing, transport::{AddressResolver, Transport}, types::SecretKey, - CheapClone, }; use smol_str::SmolStr; use crate::event::{ - InternalQueryEvent, INTERNAL_INSTALL_KEY, 
INTERNAL_LIST_KEYS, INTERNAL_REMOVE_KEY, - INTERNAL_USE_KEY, + INTERNAL_INSTALL_KEY, INTERNAL_LIST_KEYS, INTERNAL_REMOVE_KEY, INTERNAL_USE_KEY, + InternalQueryEvent, }; use super::{ + Serf, delegate::{Delegate, TransformDelegate}, error::Error, serf::{NodeResponse, QueryResponse}, types::{KeyRequestMessage, MessageType, SerfMessage}, - Serf, }; /// KeyResponse is used to relay a query for a list of all keys in use. diff --git a/core/src/lib.rs b/serf-core/src/lib.rs similarity index 97% rename from core/src/lib.rs rename to serf-core/src/lib.rs index c86944e..cc9cba2 100644 --- a/core/src/lib.rs +++ b/serf-core/src/lib.rs @@ -49,7 +49,7 @@ fn invalid_data_io_error(e: E) -> #[cfg(feature = "test")] #[cfg_attr(docsrs, doc(cfg(feature = "test")))] pub mod tests { - pub use memberlist_core::tests::{next_socket_addr_v4, next_socket_addr_v6, AnyError}; + pub use memberlist_core::tests::{AnyError, next_socket_addr_v4, next_socket_addr_v6}; pub use paste; pub use super::serf::base::tests::{serf::*, *}; diff --git a/core/src/options.rs b/serf-core/src/options.rs similarity index 100% rename from core/src/options.rs rename to serf-core/src/options.rs diff --git a/core/src/serf.rs b/serf-core/src/serf.rs similarity index 99% rename from core/src/serf.rs rename to serf-core/src/serf.rs index 1265114..d392fce 100644 --- a/core/src/serf.rs +++ b/serf-core/src/serf.rs @@ -1,27 +1,27 @@ use std::{ collections::HashMap, - sync::{atomic::AtomicBool, Arc}, + sync::{Arc, atomic::AtomicBool}, }; use async_lock::{Mutex, RwLock}; use atomic_refcell::AtomicRefCell; use futures::stream::FuturesUnordered; use memberlist_core::{ + Memberlist, agnostic_lite::{AsyncSpawner, RuntimeLite}, queue::TransmitLimitedQueue, transport::{AddressResolver, Transport}, types::MediumVec, - Memberlist, }; use super::{ + Options, broadcast::SerfBroadcast, coordinate::{Coordinate, CoordinateClient}, delegate::{CompositeDelegate, Delegate}, event::CrateEvent, snapshot::SnapshotHandle, types::{LamportClock, 
LamportTime, Members, UserEvents}, - Options, }; mod api; diff --git a/core/src/serf/api.rs b/serf-core/src/serf/api.rs similarity index 100% rename from core/src/serf/api.rs rename to serf-core/src/serf/api.rs index b425f9d..98f3ad0 100644 --- a/core/src/serf/api.rs +++ b/serf-core/src/serf/api.rs @@ -2,11 +2,11 @@ use std::sync::atomic::Ordering; use futures::{FutureExt, StreamExt}; use memberlist_core::{ + CheapClone, bytes::{BufMut, Bytes, BytesMut}, tracing, transport::{MaybeResolvedAddress, Node}, types::{Meta, OneOrMore, SmallVec}, - CheapClone, }; use smol_str::SmolStr; diff --git a/core/src/serf/base.rs b/serf-core/src/serf/base.rs similarity index 99% rename from core/src/serf/base.rs rename to serf-core/src/serf/base.rs index d17e1dd..3959819 100644 --- a/core/src/serf/base.rs +++ b/serf-core/src/serf/base.rs @@ -2,30 +2,30 @@ use std::time::Duration; use futures::{FutureExt, StreamExt}; use memberlist_core::{ + CheapClone, agnostic_lite::Detach, bytes::{BufMut, Bytes, BytesMut}, delegate::EventDelegate, tracing, transport::{MaybeResolvedAddress, Node}, types::{Meta, NodeState, OneOrMore, TinyVec}, - CheapClone, }; use rand::{Rng, SeedableRng}; use smol_str::SmolStr; use crate::{ - coalesce::{coalesced_event, MemberEventCoalescer, UserEventCoalescer}, + QueueOptions, + coalesce::{MemberEventCoalescer, UserEventCoalescer, coalesced_event}, coordinate::CoordinateOptions, delegate::TransformDelegate, error::Error, event::{InternalQueryEvent, MemberEvent, MemberEventType, QueryContext, QueryEvent}, - snapshot::{open_and_replay_snapshot, Snapshot}, + snapshot::{Snapshot, open_and_replay_snapshot}, types::{ DelegateVersion, Epoch, JoinMessage, LeaveMessage, Member, MemberState, MemberStatus, MemberlistDelegateVersion, MemberlistProtocolVersion, MessageType, NodeIntent, ProtocolVersion, QueryFlag, QueryMessage, QueryResponseMessage, SerfMessage, UserEvent, UserEventMessage, }, - QueueOptions, }; use self::internal_query::SerfQueries; @@ -414,7 +414,7 @@ where 
/// Serialize the current keyring and save it to a file. #[cfg(feature = "encryption")] pub(crate) async fn write_keyring_file(&self) -> std::io::Result<()> { - use base64::{engine::general_purpose, Engine as _}; + use base64::{Engine as _, engine::general_purpose}; let Some(path) = self.inner.opts.keyring_file() else { return Ok(()); @@ -1749,10 +1749,7 @@ where continue; } - tracing::info!( - "serf: attempting re-join to previously known node {}", - prev - ); + tracing::info!("serf: attempting re-join to previously known node {}", prev); if let Err(e) = memberlist.join(prev.cheap_clone()).await { tracing::warn!( "serf: failed to re-join to previously known node {}: {}", diff --git a/core/src/serf/base/tests.rs b/serf-core/src/serf/base/tests.rs similarity index 99% rename from core/src/serf/base/tests.rs rename to serf-core/src/serf/base/tests.rs index 2d3d1f7..4c0459d 100644 --- a/core/src/serf/base/tests.rs +++ b/serf-core/src/serf/base/tests.rs @@ -8,7 +8,7 @@ use memberlist_core::{ transport::MaybeResolvedAddress, types::{OneOrMore, TinyVec}, }; -use serf_types::{ +use serf_proto::{ MessageType, Node, PushPullMessage, QueryFlag, QueryMessage, SerfMessage, UserEvent, UserEventMessage, }; @@ -340,7 +340,7 @@ pub async fn estimate_max_keys_in_list_key_response_factor( T: Transport, { use memberlist_core::types::SecretKey; - use serf_types::KeyResponseMessage; + use serf_proto::KeyResponseMessage; let size_limit = opts.query_response_size_limit() * 10; let opts = opts.with_query_response_size_limit(size_limit); @@ -400,7 +400,7 @@ where T: Transport, { use memberlist_core::types::SecretKey; - use serf_types::{Encodable, KeyResponseMessage}; + use serf_proto::{Encodable, KeyResponseMessage}; let opts = opts.with_query_response_size_limit(1024); let s = Serf::::new(transport_opts, opts).await.unwrap(); diff --git a/core/src/serf/base/tests/serf.rs b/serf-core/src/serf/base/tests/serf.rs similarity index 99% rename from core/src/serf/base/tests/serf.rs rename to 
serf-core/src/serf/base/tests/serf.rs index 1b12f45..c68ef27 100644 --- a/core/src/serf/base/tests/serf.rs +++ b/serf-core/src/serf/base/tests/serf.rs @@ -2,7 +2,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use memberlist_core::{tests::AnyError, transport::Id}; -use serf_types::{Member, MemberStatus, Tags}; +use serf_proto::{Member, MemberStatus, Tags}; use crate::{event::EventProducer, types::MemberState}; @@ -788,7 +788,7 @@ pub async fn serf_write_keyring_file( { use std::io::Read; - use base64::{engine::general_purpose, Engine as _}; + use base64::{Engine as _, engine::general_purpose}; const EXISTING: &str = "T9jncgl9mbLus+baTTa7q7nPSUrXwbDi2dhbtqir37s="; const NEW_KEY: &str = "HvY8ubRZMgafUOWvrOadwOckVa1wN3QWAo46FVKbVN8="; diff --git a/core/src/serf/base/tests/serf/delegate.rs b/serf-core/src/serf/base/tests/serf/delegate.rs similarity index 100% rename from core/src/serf/base/tests/serf/delegate.rs rename to serf-core/src/serf/base/tests/serf/delegate.rs diff --git a/core/src/serf/base/tests/serf/event.rs b/serf-core/src/serf/base/tests/serf/event.rs similarity index 99% rename from core/src/serf/base/tests/serf/event.rs rename to serf-core/src/serf/base/tests/serf/event.rs index 0a89095..7f4ddbe 100644 --- a/core/src/serf/base/tests/serf/event.rs +++ b/serf-core/src/serf/base/tests/serf/event.rs @@ -1,4 +1,4 @@ -use serf_types::{Filter, FilterType}; +use serf_proto::{Filter, FilterType}; use super::*; diff --git a/core/src/serf/base/tests/serf/join.rs b/serf-core/src/serf/base/tests/serf/join.rs similarity index 93% rename from core/src/serf/base/tests/serf/join.rs rename to serf-core/src/serf/base/tests/serf/join.rs index 535fb3d..4ee840b 100644 --- a/core/src/serf/base/tests/serf/join.rs +++ b/serf-core/src/serf/base/tests/serf/join.rs @@ -53,10 +53,10 @@ pub async fn join_intent_old_message( node: Node::new("test".into(), addr), tags: Arc::new(Default::default()), status: MemberStatus::Alive, - memberlist_protocol_version: 
serf_types::MemberlistProtocolVersion::V1, - memberlist_delegate_version: serf_types::MemberlistDelegateVersion::V1, - protocol_version: serf_types::ProtocolVersion::V1, - delegate_version: serf_types::DelegateVersion::V1, + memberlist_protocol_version: serf_proto::MemberlistProtocolVersion::V1, + memberlist_delegate_version: serf_proto::MemberlistDelegateVersion::V1, + protocol_version: serf_proto::ProtocolVersion::V1, + delegate_version: serf_proto::DelegateVersion::V1, }, status_time: 12.into(), leave_time: None, @@ -104,10 +104,10 @@ pub async fn join_intent_newer( node: Node::new("test".into(), addr), tags: Arc::new(Default::default()), status: MemberStatus::Alive, - memberlist_protocol_version: serf_types::MemberlistProtocolVersion::V1, - memberlist_delegate_version: serf_types::MemberlistDelegateVersion::V1, - protocol_version: serf_types::ProtocolVersion::V1, - delegate_version: serf_types::DelegateVersion::V1, + memberlist_protocol_version: serf_proto::MemberlistProtocolVersion::V1, + memberlist_delegate_version: serf_proto::MemberlistDelegateVersion::V1, + protocol_version: serf_proto::ProtocolVersion::V1, + delegate_version: serf_proto::DelegateVersion::V1, }, status_time: 12.into(), leave_time: None, @@ -156,10 +156,10 @@ pub async fn join_intent_reset_leaving( node: Node::new("test".into(), addr), tags: Arc::new(Default::default()), status: MemberStatus::Leaving, - memberlist_protocol_version: serf_types::MemberlistProtocolVersion::V1, - memberlist_delegate_version: serf_types::MemberlistDelegateVersion::V1, - protocol_version: serf_types::ProtocolVersion::V1, - delegate_version: serf_types::DelegateVersion::V1, + memberlist_protocol_version: serf_proto::MemberlistProtocolVersion::V1, + memberlist_delegate_version: serf_proto::MemberlistDelegateVersion::V1, + protocol_version: serf_proto::ProtocolVersion::V1, + delegate_version: serf_proto::DelegateVersion::V1, }, status_time: 12.into(), leave_time: None, @@ -294,8 +294,8 @@ pub async fn 
join_pending_intent( addr, meta: Meta::empty(), state: memberlist_core::types::State::Alive, - protocol_version: serf_types::MemberlistProtocolVersion::V1, - delegate_version: serf_types::MemberlistDelegateVersion::V1, + protocol_version: serf_proto::MemberlistProtocolVersion::V1, + delegate_version: serf_proto::MemberlistDelegateVersion::V1, })) .await; @@ -341,8 +341,8 @@ pub async fn join_pending_intents( addr, meta: Meta::empty(), state: memberlist_core::types::State::Alive, - protocol_version: serf_types::MemberlistProtocolVersion::V1, - delegate_version: serf_types::MemberlistDelegateVersion::V1, + protocol_version: serf_proto::MemberlistProtocolVersion::V1, + delegate_version: serf_proto::MemberlistDelegateVersion::V1, })) .await; diff --git a/core/src/serf/base/tests/serf/leave.rs b/serf-core/src/serf/base/tests/serf/leave.rs similarity index 96% rename from core/src/serf/base/tests/serf/leave.rs rename to serf-core/src/serf/base/tests/serf/leave.rs index f1f8e82..aa705d9 100644 --- a/core/src/serf/base/tests/serf/leave.rs +++ b/serf-core/src/serf/base/tests/serf/leave.rs @@ -50,10 +50,10 @@ pub async fn leave_intent_old_message( node: Node::new("test".into(), addr), tags: Arc::new(Default::default()), status: MemberStatus::Alive, - memberlist_protocol_version: serf_types::MemberlistProtocolVersion::V1, - memberlist_delegate_version: serf_types::MemberlistDelegateVersion::V1, - protocol_version: serf_types::ProtocolVersion::V1, - delegate_version: serf_types::DelegateVersion::V1, + memberlist_protocol_version: serf_proto::MemberlistProtocolVersion::V1, + memberlist_delegate_version: serf_proto::MemberlistDelegateVersion::V1, + protocol_version: serf_proto::ProtocolVersion::V1, + delegate_version: serf_proto::DelegateVersion::V1, }, status_time: 12.into(), leave_time: None, @@ -101,10 +101,10 @@ pub async fn leave_intent_newer( node: Node::new("test".into(), addr), tags: Arc::new(Default::default()), status: MemberStatus::Alive, - 
memberlist_protocol_version: serf_types::MemberlistProtocolVersion::V1, - memberlist_delegate_version: serf_types::MemberlistDelegateVersion::V1, - protocol_version: serf_types::ProtocolVersion::V1, - delegate_version: serf_types::DelegateVersion::V1, + memberlist_protocol_version: serf_proto::MemberlistProtocolVersion::V1, + memberlist_delegate_version: serf_proto::MemberlistDelegateVersion::V1, + protocol_version: serf_proto::ProtocolVersion::V1, + delegate_version: serf_proto::DelegateVersion::V1, }, status_time: 12.into(), leave_time: None, diff --git a/core/src/serf/base/tests/serf/reap.rs b/serf-core/src/serf/base/tests/serf/reap.rs similarity index 100% rename from core/src/serf/base/tests/serf/reap.rs rename to serf-core/src/serf/base/tests/serf/reap.rs diff --git a/core/src/serf/base/tests/serf/reconnect.rs b/serf-core/src/serf/base/tests/serf/reconnect.rs similarity index 100% rename from core/src/serf/base/tests/serf/reconnect.rs rename to serf-core/src/serf/base/tests/serf/reconnect.rs diff --git a/core/src/serf/base/tests/serf/remove.rs b/serf-core/src/serf/base/tests/serf/remove.rs similarity index 100% rename from core/src/serf/base/tests/serf/remove.rs rename to serf-core/src/serf/base/tests/serf/remove.rs diff --git a/core/src/serf/base/tests/serf/snapshot.rs b/serf-core/src/serf/base/tests/serf/snapshot.rs similarity index 99% rename from core/src/serf/base/tests/serf/snapshot.rs rename to serf-core/src/serf/base/tests/serf/snapshot.rs index e4a69f4..4b38230 100644 --- a/core/src/serf/base/tests/serf/snapshot.rs +++ b/serf-core/src/serf/base/tests/serf/snapshot.rs @@ -615,7 +615,7 @@ pub async fn serf_snapshot_recovery( async fn test_snapshoter_slow_disk_not_blocking_event_tx() { use memberlist_core::{ agnostic_lite::tokio::TokioRuntime, - transport::{resolver::socket_addr::SocketAddrResolver, tests::UnimplementedTransport, Lpe}, + transport::{Lpe, resolver::socket_addr::SocketAddrResolver, tests::UnimplementedTransport}, }; use 
std::net::SocketAddr; @@ -732,7 +732,7 @@ async fn test_snapshoter_slow_disk_not_blocking_event_tx() { async fn test_snapshoter_slow_disk_not_blocking_memberlist() { use memberlist_core::{ agnostic_lite::tokio::TokioRuntime, - transport::{resolver::socket_addr::SocketAddrResolver, tests::UnimplementedTransport, Lpe}, + transport::{Lpe, resolver::socket_addr::SocketAddrResolver, tests::UnimplementedTransport}, }; use std::net::SocketAddr; diff --git a/core/src/serf/delegate.rs b/serf-core/src/serf/delegate.rs similarity index 99% rename from core/src/serf/delegate.rs rename to serf-core/src/serf/delegate.rs index 803614d..9db1949 100644 --- a/core/src/serf/delegate.rs +++ b/serf-core/src/serf/delegate.rs @@ -1,4 +1,5 @@ use crate::{ + Serf, broadcast::SerfBroadcast, delegate::{Delegate, TransformDelegate}, error::{SerfDelegateError, SerfError}, @@ -8,14 +9,14 @@ use crate::{ MemberlistDelegateVersion, MemberlistProtocolVersion, MessageType, ProtocolVersion, PushPullMessageRef, SerfMessage, UserEventMessage, }, - Serf, }; -use std::sync::{atomic::Ordering, Arc, OnceLock}; +use std::sync::{Arc, OnceLock, atomic::Ordering}; use arc_swap::ArcSwap; use indexmap::IndexSet; use memberlist_core::{ + CheapClone, META_MAX_SIZE, bytes::{Buf, BufMut, Bytes, BytesMut}, delegate::{ AliveDelegate, ConflictDelegate, Delegate as MemberlistDelegate, EventDelegate, @@ -24,9 +25,8 @@ use memberlist_core::{ tracing, transport::{AddressResolver, Transport}, types::{Meta, NodeState, SmallVec, State, TinyVec}, - CheapClone, META_MAX_SIZE, }; -use serf_types::Tags; +use serf_proto::Tags; // PingVersion is an internal version for the ping message, above the normal // versioning we get from the protocol version. 
This enables small updates diff --git a/core/src/serf/internal_query.rs b/serf-core/src/serf/internal_query.rs similarity index 99% rename from core/src/serf/internal_query.rs rename to serf-core/src/serf/internal_query.rs index fbc8b9e..5503ce6 100644 --- a/core/src/serf/internal_query.rs +++ b/serf-core/src/serf/internal_query.rs @@ -1,4 +1,4 @@ -use async_channel::{bounded, Receiver, Sender}; +use async_channel::{Receiver, Sender, bounded}; use futures::FutureExt; use memberlist_core::{ agnostic_lite::{AsyncSpawner, RuntimeLite}, @@ -438,7 +438,7 @@ where ) -> Result< ( Bytes, - serf_types::QueryResponseMessage::ResolvedAddress>, + serf_proto::QueryResponseMessage::ResolvedAddress>, ), Error, > { diff --git a/core/src/serf/query.rs b/serf-core/src/serf/query.rs similarity index 99% rename from core/src/serf/query.rs rename to serf-core/src/serf/query.rs index d3693fd..e7b218e 100644 --- a/core/src/serf/query.rs +++ b/serf-core/src/serf/query.rs @@ -6,13 +6,13 @@ use std::{ use async_channel::{Receiver, Sender}; use async_lock::RwLock; -use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; +use futures::{FutureExt, StreamExt, stream::FuturesUnordered}; use memberlist_core::{ + CheapClone, bytes::{BufMut, Bytes, BytesMut}, tracing, transport::{AddressResolver, Id, Node, Transport}, types::{OneOrMore, SmallVec, TinyVec}, - CheapClone, }; use crate::{ diff --git a/core/src/snapshot.rs b/serf-core/src/snapshot.rs similarity index 99% rename from core/src/snapshot.rs rename to serf-core/src/snapshot.rs index 6adcdad..fb11c76 100644 --- a/core/src/snapshot.rs +++ b/serf-core/src/snapshot.rs @@ -15,15 +15,15 @@ use async_channel::{Receiver, Sender}; use byteorder::{LittleEndian, ReadBytesExt}; use futures::FutureExt; use memberlist_core::{ + CheapClone, agnostic_lite::{AsyncSpawner, RuntimeLite}, bytes::{BufMut, BytesMut}, tracing, transport::{AddressResolver, Id, MaybeResolvedAddress, Node, Transport}, types::TinyVec, - CheapClone, }; use 
rand::seq::SliceRandom; -use serf_types::UserEventMessage; +use serf_proto::UserEventMessage; use crate::{ delegate::{Delegate, TransformDelegate}, @@ -183,9 +183,7 @@ macro_rules! encode { data[1..N].copy_from_slice(&$t.to_le_bytes()); $w.write_all(&data).map(|_| N) }}; - ($w:ident.$ident: ident) => {{ - $w.write_all(&[Self::$ident]).map(|_| 1) - }}; + ($w:ident.$ident: ident) => {{ $w.write_all(&[Self::$ident]).map(|_| 1) }}; } impl SnapshotRecord<'_, I, A> diff --git a/core/src/types.rs b/serf-core/src/types.rs similarity index 99% rename from core/src/types.rs rename to serf-core/src/types.rs index 4c9f4b9..2e5d0b2 100644 --- a/core/src/types.rs +++ b/serf-core/src/types.rs @@ -1,4 +1,4 @@ -pub use serf_types::*; +pub use serf_proto::*; mod member; pub(crate) use member::*; diff --git a/core/src/types/member.rs b/serf-core/src/types/member.rs similarity index 98% rename from core/src/types/member.rs rename to serf-core/src/types/member.rs index 99f6f44..9f4a5f6 100644 --- a/core/src/types/member.rs +++ b/serf-core/src/types/member.rs @@ -1,5 +1,5 @@ use memberlist_core::types::OneOrMore; -use serf_types::Member; +use serf_proto::Member; use std::collections::HashMap; diff --git a/core/src/types/message.rs b/serf-core/src/types/message.rs similarity index 100% rename from core/src/types/message.rs rename to serf-core/src/types/message.rs diff --git a/types/Cargo.toml b/serf-proto/Cargo.toml similarity index 54% rename from types/Cargo.toml rename to serf-proto/Cargo.toml index e92dade..61fc0de 100644 --- a/types/Cargo.toml +++ b/serf-proto/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "serf-types" -version.workspace = true +name = "serf-proto" +version = "0.1.0" rust-version.workspace = true edition.workspace = true repository.workspace = true @@ -9,25 +9,30 @@ license.workspace = true description = "Types for the `serf` crate" [features] -encryption = ["memberlist-types/encryption", "futures"] -serde = ["dep:serde", "indexmap/serde", "memberlist-types/serde", 
"smol_str/serde", "bitflags/serde"] -metrics = ["memberlist-types/metrics"] +encryption = ["memberlist-proto/encryption", "futures"] +serde = ["dep:serde", "indexmap/serde", "memberlist-proto/serde", "smol_str/serde", "bitflags/serde"] +metrics = ["memberlist-proto/metrics"] + +arbitrary = ["dep:arbitrary", "memberlist-proto/arbitrary", "smol_str/arbitrary"] +quickcheck = ["dep:quickcheck", "memberlist-proto/quickcheck"] [dependencies] bitflags = "2" byteorder.workspace = true bytemuck = { version = "1", features = ["derive"] } -derive_more.workspace = true +derive_more = { workspace = true, features = ["is_variant", "display"] } futures = { workspace = true, optional = true, features = ["alloc"] } indexmap.workspace = true -memberlist-types.workspace = true +memberlist-proto.workspace = true smol_str.workspace = true -transformable = { version = "0.2", features = ["async"] } thiserror.workspace = true viewit.workspace = true serde = { workspace = true, optional = true } +arbitrary = { version = "1", optional = true, default-features = false, features = ["derive"] } +quickcheck = { version = "1", optional = true, default-features = false } + [dev-dependencies] rand.workspace = true futures = { workspace = true, features = ["executor"] } diff --git a/serf-proto/src/arbitrary_impl.rs b/serf-proto/src/arbitrary_impl.rs new file mode 100644 index 0000000..acb21ba --- /dev/null +++ b/serf-proto/src/arbitrary_impl.rs @@ -0,0 +1,73 @@ +use std::{collections::{HashMap, HashSet}, hash::Hash}; + +use super::Filter; +use arbitrary::{Arbitrary, Unstructured}; +use indexmap::{IndexMap, IndexSet}; +use memberlist_proto::TinyVec; + +pub(super) fn into<'a, F, T>(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result +where + F: arbitrary::Arbitrary<'a>, + T: From, +{ + u.arbitrary::().map(Into::into) +} + +pub(super) fn arbitrary_indexmap<'a, K, V>(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result> +where + K: Arbitrary<'a> + Hash + Eq, + V: Arbitrary<'a>, +{ + let map 
= u.arbitrary::>()?; + Ok(IndexMap::from_iter(map)) +} + +pub(super) fn arbitrary_indexset<'a, K>(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result> +where + K: Arbitrary<'a> + Hash + Eq, +{ + let map = u.arbitrary::>()?; + Ok(IndexSet::from_iter(map)) +} + + +impl<'a, I> Arbitrary<'a> for Filter +where + I: Arbitrary<'a>, +{ + fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { + let kind = u.arbitrary::()?; + Ok(if kind { + Filter::Id(into::, TinyVec<_>>(u)?) + } else { + Filter::Tag { tag: u.arbitrary()?, expr: u.arbitrary()? } + }) + } +} + +impl<'a> Arbitrary<'a> for super::QueryFlag { + fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { + Ok(if u.arbitrary()? { + Self::NO_BROADCAST + } else { + Self::ACK + }) + } +} + +#[cfg(feature = "encryption")] +impl<'a, I> Arbitrary<'a> for super::KeyResponse +where + I: Arbitrary<'a> + Eq + std::hash::Hash, +{ + fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { + Ok(Self { + messages: arbitrary_indexmap(u)?, + num_nodes: u.arbitrary()?, + num_resp: u.arbitrary()?, + num_err: u.arbitrary()?, + keys: arbitrary_indexmap(u)?, + primary_keys: arbitrary_indexmap(u)?, + }) + } +} diff --git a/types/src/clock.rs b/serf-proto/src/clock.rs similarity index 80% rename from types/src/clock.rs rename to serf-proto/src/clock.rs index 1b075c0..b1b7143 100644 --- a/types/src/clock.rs +++ b/serf-proto/src/clock.rs @@ -1,14 +1,15 @@ use std::sync::{ - atomic::{AtomicU64, Ordering}, Arc, + atomic::{AtomicU64, Ordering}, }; -use transformable::{utils::*, Transformable}; +use memberlist_proto::{Data, DataRef}; /// A lamport time is a simple u64 that represents a point in time. 
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde", serde(transparent))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[repr(transparent)] pub struct LamportTime(pub(crate) u64); @@ -92,35 +93,27 @@ impl core::ops::Rem for LamportTime { } } -/// Error that can occur when transforming a lamport time -#[derive(thiserror::Error, Debug)] -pub enum LamportTimeTransformError { - /// Encode varint error - #[error(transparent)] - Encode(#[from] InsufficientBuffer), - /// Decode varint error - #[error(transparent)] - Decode(#[from] DecodeVarintError), -} +impl Data for LamportTime { + type Ref<'a> = Self; -impl Transformable for LamportTime { - type Error = LamportTimeTransformError; + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized { + Ok(val) + } - fn encode(&self, dst: &mut [u8]) -> Result { - encode_u64_varint(self.0, dst).map_err(Into::into) - } + fn encoded_len(&self) -> usize { + ::encoded_len(&self.0) + } - fn encoded_len(&self) -> usize { - encoded_u64_varint_len(self.0) - } + fn encode(&self, buf: &mut [u8]) -> Result { + ::encode(&self.0, buf) + } +} - fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error> - where - Self: Sized, - { - decode_u64_varint(src) - .map(|(n, time)| (n, Self(time))) - .map_err(Into::into) +impl<'a> DataRef<'a, LamportTime> for LamportTime { + fn decode(src: &'a [u8]) -> Result<(usize, LamportTime), memberlist_proto::DecodeError> { + >::decode(src).map(|(n, v)| (n, v.into())) } } @@ -178,13 +171,16 @@ impl LamportClock { } } -#[cfg(test)] -impl LamportTime { - pub(crate) fn random() -> Self { - use rand::Rng; - Self(rand::thread_rng().gen_range(0..u64::MAX)) +#[cfg(feature = "quickcheck")] +const _: () = { + use quickcheck::{Arbitrary, Gen}; + + impl Arbitrary for LamportTime { + fn arbitrary(g: &mut Gen) -> Self { + Self(u64::arbitrary(g)) + } } -} +}; #[test] fn 
test_lamport_clock() { diff --git a/serf-proto/src/filter.rs b/serf-proto/src/filter.rs new file mode 100644 index 0000000..b2c27cb --- /dev/null +++ b/serf-proto/src/filter.rs @@ -0,0 +1,80 @@ +use memberlist_proto::TinyVec; +use smol_str::SmolStr; + +/// The type of filter +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, derive_more::IsVariant, derive_more::Display)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +#[non_exhaustive] +pub enum FilterType { + /// Filter by node ids + #[display("id")] + Id, + /// Filter by tag + #[display("tag")] + Tag, + /// Unknown filter type + #[display("unknown({_0})")] + Unknown(u8), +} + +impl FilterType { + /// Get the string representation of the filter type + #[inline] + pub fn as_str(&self) -> std::borrow::Cow<'static, str> { + std::borrow::Cow::Borrowed(match self { + Self::Id => "id", + Self::Tag => "tag", + Self::Unknown(val) => return std::borrow::Cow::Owned(format!("unknown({})", val)), + }) + } +} + +impl From for FilterType { + fn from(value: u8) -> Self { + match value { + 0 => Self::Id, + 1 => Self::Tag, + val => Self::Unknown(val), + } + } +} + +impl From for u8 { + fn from(val: FilterType) -> Self { + match val { + FilterType::Id => 0, + FilterType::Tag => 1, + FilterType::Unknown(val) => val, + } + } +} + +/// Used with a queryFilter to specify the type of +/// filter we are sending +#[derive(Debug, Clone, Eq, PartialEq, derive_more::IsVariant)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[non_exhaustive] +pub enum Filter { + /// Filter by node ids + Id(TinyVec), + /// Filter by tag + Tag { + /// The tag to filter by + tag: SmolStr, + /// The expression to filter by + expr: SmolStr, + }, +} + +impl Filter { + /// Returns the type of filter + #[inline] + pub const fn ty(&self) -> FilterType { + match self { + Self::Id(_) => FilterType::Id, + Self::Tag { .. 
} => FilterType::Tag, + } + } +} + diff --git a/serf-proto/src/join.rs b/serf-proto/src/join.rs new file mode 100644 index 0000000..9cce02d --- /dev/null +++ b/serf-proto/src/join.rs @@ -0,0 +1,46 @@ +use super::LamportTime; + +/// The message broadcasted after we join to +/// associate the node with a lamport clock +#[viewit::viewit(setters(prefix = "with"))] +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct JoinMessage { + /// The lamport time + #[viewit( + getter(const, attrs(doc = "Returns the lamport time for this message")), + setter( + const, + attrs(doc = "Sets the lamport time for this message (Builder pattern)") + ) + )] + ltime: LamportTime, + /// The id of the node + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the node")), + setter(attrs(doc = "Sets the node (Builder pattern)")) + )] + id: I, +} + +impl JoinMessage { + /// Create a new join message + pub fn new(ltime: LamportTime, id: I) -> Self { + Self { ltime, id } + } + + /// Set the lamport time + #[inline] + pub fn set_ltime(&mut self, ltime: LamportTime) -> &mut Self { + self.ltime = ltime; + self + } + + /// Set the id of the node + #[inline] + pub fn set_id(&mut self, id: I) -> &mut Self { + self.id = id; + self + } +} diff --git a/serf-proto/src/key.rs b/serf-proto/src/key.rs new file mode 100644 index 0000000..836ecce --- /dev/null +++ b/serf-proto/src/key.rs @@ -0,0 +1,153 @@ +use indexmap::IndexMap; +use memberlist_proto::{SecretKey, SecretKeys}; +use smol_str::SmolStr; + +/// KeyRequest is used to contain input parameters which get broadcasted to all +/// nodes as part of a key query operation. 
+#[viewit::viewit(setters(prefix = "with"))] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(transparent)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(transparent))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct KeyRequestMessage { + /// The secret key + #[viewit( + getter(const, attrs(doc = "Returns the secret key")), + setter(const, attrs(doc = "Sets the secret key (Builder pattern)")) + )] + key: Option, +} + +/// Key response message +#[viewit::viewit(setters(prefix = "with"))] +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +#[cfg(feature = "encryption")] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct KeyResponseMessage { + /// Indicates true/false if there were errors or not + #[viewit( + getter(const, attrs(doc = "Returns true/false if there were errors or not")), + setter( + const, + attrs(doc = "Sets true/false if there were errors or not (Builder pattern)") + ) + )] + result: bool, + /// Contains error messages or other information + #[viewit( + getter( + const, + style = "ref", + attrs(doc = "Returns the error messages or other information") + ), + setter(attrs(doc = "Sets the error messages or other information (Builder pattern)")) + )] + message: SmolStr, + /// Used in listing queries to relay a list of installed keys + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns a list of installed keys")), + setter(attrs(doc = "Sets the list of installed keys (Builder pattern)")) + )] + keys: SecretKeys, + /// Used in listing queries to relay the primary key + #[viewit( + getter(const, attrs(doc = "Returns the primary key")), + setter(attrs(doc = "Sets the primary key (Builder pattern)")) + )] + primary_key: Option, +} + +impl KeyResponseMessage { + /// Adds a key to the list of keys + #[inline] + pub fn add_key(&mut self, 
key: SecretKey) -> &mut Self { + self.keys.push(key); + self + } +} + +/// KeyResponse is used to relay a query for a list of all keys in use. +#[viewit::viewit(setters(prefix = "with"))] +#[derive(Default)] +pub struct KeyResponse { + /// Map of node id to response message + #[viewit( + getter( + const, + style = "ref", + attrs(doc = "Returns the map of node id to response message") + ), + setter(attrs(doc = "Sets the map of node id to response message (Builder pattern)")) + )] + messages: IndexMap, + /// Total nodes memberlist knows of + #[viewit( + getter(const, attrs(doc = "Returns the total nodes memberlist knows of")), + setter( + const, + attrs(doc = "Sets total nodes memberlist knows of (Builder pattern)") + ) + )] + num_nodes: usize, + /// Total responses received + #[viewit( + getter(const, attrs(doc = "Returns the total responses received")), + setter( + const, + attrs(doc = "Sets the total responses received (Builder pattern)") + ) + )] + num_resp: usize, + /// Total errors from request + #[viewit( + getter(const, attrs(doc = "Returns the total errors from request")), + setter( + const, + attrs(doc = "Sets the total errors from request (Builder pattern)") + ) + )] + num_err: usize, + + /// A mapping of the value of the key bytes to the + /// number of nodes that have the key installed. + #[viewit( + getter( + const, + style = "ref", + attrs( + doc = "Returns a mapping of the value of the key bytes to the number of nodes that have the key installed." + ) + ), + setter(attrs( + doc = "Sets a mapping of the value of the key bytes to the number of nodes that have the key installed (Builder pattern)" + )) + )] + keys: IndexMap, + + /// A mapping of the value of the primary + /// key bytes to the number of nodes that have the key installed. + #[viewit( + getter( + const, + style = "ref", + attrs( + doc = "Returns a mapping of the value of the primary key bytes to the number of nodes that have the key installed." 
+ ) + ), + setter(attrs( + doc = "Sets a mapping of the value of the primary key bytes to the number of nodes that have the key installed. (Builder pattern)" + )) + )] + primary_keys: IndexMap, +} + +/// KeyRequestOptions is used to contain optional parameters for a keyring operation +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct KeyRequestOptions { + /// The number of duplicate query responses to send by relaying through + /// other nodes, for redundancy + pub relay_factor: u8, +} diff --git a/serf-proto/src/leave.rs b/serf-proto/src/leave.rs new file mode 100644 index 0000000..ba3f2e4 --- /dev/null +++ b/serf-proto/src/leave.rs @@ -0,0 +1,33 @@ +use super::LamportTime; + +/// The message broadcasted to signal the intention to +/// leave. +#[viewit::viewit(setters(prefix = "with"))] +#[derive(Debug, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct LeaveMessage { + /// The lamport time + #[viewit( + getter(const, attrs(doc = "Returns the lamport time for this message")), + setter( + const, + attrs(doc = "Sets the lamport time for this message (Builder pattern)") + ) + )] + ltime: LamportTime, + /// The id of the node + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the node")), + setter(attrs(doc = "Sets the node (Builder pattern)")) + )] + id: I, + + /// If prune or not + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns if prune or not")), + setter(attrs(doc = "Sets prune or not (Builder pattern)")) + )] + prune: bool, +} + diff --git a/types/src/lib.rs b/serf-proto/src/lib.rs similarity index 69% rename from types/src/lib.rs rename to serf-proto/src/lib.rs index 9352750..80b0bdc 100644 --- a/types/src/lib.rs +++ b/serf-proto/src/lib.rs @@ -6,13 +6,13 @@ #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, 
allow(unused_attributes))] -pub use memberlist_types::{ - DelegateVersion as MemberlistDelegateVersion, Node, NodeAddress, NodeAddressError, NodeId, - NodeIdTransformError, NodeTransformError, ProtocolVersion as MemberlistProtocolVersion, - UnknownDelegateVersion as UnknownMemberlistDelegateVersion, - UnknownProtocolVersion as UnknownMemberlistProtocolVersion, +pub use memberlist_proto::{ + DelegateVersion as MemberlistDelegateVersion, Node, NodeId, HostAddr, ParseDomainError, ParseHostAddrError, Domain, + ProtocolVersion as MemberlistProtocolVersion, }; -pub use transformable::{Encodable, Transformable}; + +#[cfg(feature = "arbitrary")] +mod arbitrary_impl; mod clock; pub use clock::*; diff --git a/serf-proto/src/member.rs b/serf-proto/src/member.rs new file mode 100644 index 0000000..23405ac --- /dev/null +++ b/serf-proto/src/member.rs @@ -0,0 +1,187 @@ +use std::sync::Arc; + +use memberlist_proto::CheapClone; + +use super::{ + DelegateVersion, MemberlistDelegateVersion, MemberlistProtocolVersion, Node, + ProtocolVersion, Tags, +}; + +const MEMBER_STATUS_NONE: u8 = 0; +const MEMBER_STATUS_ALIVE: u8 = 1; +const MEMBER_STATUS_LEAVING: u8 = 2; +const MEMBER_STATUS_LEFT: u8 = 3; +const MEMBER_STATUS_FAILED: u8 = 4; + + +/// The member status. 
+#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash, derive_more::IsVariant, derive_more::Display)] +#[repr(u8)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +#[non_exhaustive] +pub enum MemberStatus { + /// None status + #[display("none")] + #[default] + None, + /// Alive status + #[display("alive")] + Alive, + /// Leaving status + #[display("leaving")] + Leaving, + /// Left status + #[display("left")] + Left, + /// Failed status + #[display("failed")] + Failed, + /// Unknown state (used for forwards and backwards compatibility) + #[display("unknown({_0})")] + Unknown(u8), +} + +impl From for MemberStatus { + fn from(value: u8) -> Self { + match value { + MEMBER_STATUS_NONE => Self::None, + MEMBER_STATUS_ALIVE => Self::Alive, + MEMBER_STATUS_LEAVING => Self::Leaving, + MEMBER_STATUS_LEFT => Self::Left, + MEMBER_STATUS_FAILED => Self::Failed, + val => Self::Unknown(val), + } + } +} + +impl From for u8 { + fn from(val: MemberStatus) -> Self { + match val { + MemberStatus::None => MEMBER_STATUS_NONE, + MemberStatus::Alive => MEMBER_STATUS_ALIVE, + MemberStatus::Leaving => MEMBER_STATUS_LEAVING, + MemberStatus::Left => MEMBER_STATUS_LEFT, + MemberStatus::Failed => MEMBER_STATUS_FAILED, + MemberStatus::Unknown(val) => val, + } + } +} + +impl MemberStatus { + /// Get the string representation of the member status + #[inline] + pub fn as_str(&self) -> std::borrow::Cow<'static, str> { + std::borrow::Cow::Borrowed(match self { + Self::None => "none", + Self::Alive => "alive", + Self::Leaving => "leaving", + Self::Left => "left", + Self::Failed => "failed", + Self::Unknown(val) => return format!("unknown({})", val).into(), + }) + } +} + + +/// A single member of the Serf cluster. 
+#[viewit::viewit(setters(prefix = "with"))] +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct Member { + /// The node + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the node")), + setter(attrs(doc = "Sets the node (Builder pattern)")) + )] + node: Node, + /// The tags + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the tags")), + setter(attrs(doc = "Sets the tags (Builder pattern)")) + )] + tags: Arc, + /// The status + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the status")), + setter(attrs(doc = "Sets the status (Builder pattern)")) + )] + status: MemberStatus, + /// The memberlist protocol version + #[viewit( + getter(const, attrs(doc = "Returns the memberlist protocol version")), + setter( + const, + attrs(doc = "Sets the memberlist protocol version (Builder pattern)") + ) + )] + memberlist_protocol_version: MemberlistProtocolVersion, + /// The memberlist delegate version + #[viewit( + getter(const, attrs(doc = "Returns the memberlist delegate version")), + setter( + const, + attrs(doc = "Sets the memberlist delegate version (Builder pattern)") + ) + )] + memberlist_delegate_version: MemberlistDelegateVersion, + + /// The serf protocol version + #[viewit( + getter(const, attrs(doc = "Returns the serf protocol version")), + setter(const, attrs(doc = "Sets the serf protocol version (Builder pattern)")) + )] + protocol_version: ProtocolVersion, + /// The serf delegate version + #[viewit( + getter(const, attrs(doc = "Returns the serf delegate version")), + setter(const, attrs(doc = "Sets the serf delegate version (Builder pattern)")) + )] + delegate_version: DelegateVersion, +} + +impl Member { + /// Create a new member with the given node, tags, and status. + /// Other fields are set to their default values. 
+ #[inline] + pub fn new(node: Node, tags: Tags, status: MemberStatus) -> Self { + Self { + node, + tags: Arc::new(tags), + status, + memberlist_protocol_version: MemberlistProtocolVersion::V1, + memberlist_delegate_version: MemberlistDelegateVersion::V1, + protocol_version: ProtocolVersion::V1, + delegate_version: DelegateVersion::V1, + } + } +} + +impl Clone for Member { + fn clone(&self) -> Self { + Self { + node: self.node.clone(), + tags: self.tags.clone(), + status: self.status, + memberlist_protocol_version: self.memberlist_protocol_version, + memberlist_delegate_version: self.memberlist_delegate_version, + protocol_version: self.protocol_version, + delegate_version: self.delegate_version, + } + } +} + +impl CheapClone for Member { + fn cheap_clone(&self) -> Self { + Self { + node: self.node.cheap_clone(), + tags: self.tags.cheap_clone(), + status: self.status, + memberlist_protocol_version: self.memberlist_protocol_version, + memberlist_delegate_version: self.memberlist_delegate_version, + protocol_version: self.protocol_version, + delegate_version: self.delegate_version, + } + } +} diff --git a/types/src/message.rs b/serf-proto/src/message.rs similarity index 74% rename from types/src/message.rs rename to serf-proto/src/message.rs index be48fd1..f126dc0 100644 --- a/types/src/message.rs +++ b/serf-proto/src/message.rs @@ -1,14 +1,8 @@ use std::sync::Arc; -use crate::{ - JoinMessageTransformError, LeaveMessageTransformError, MemberTransformError, - PushPullMessageTransformError, QueryMessageTransformError, QueryResponseMessageTransformError, - UserEventMessageTransformError, -}; - use super::{ - Encodable, JoinMessage, LeaveMessage, Member, PushPullMessage, PushPullMessageRef, QueryMessage, - QueryResponseMessage, Transformable, UserEventMessage, + JoinMessage, LeaveMessage, Member, PushPullMessage, PushPullMessageRef, QueryMessage, + QueryResponseMessage, UserEventMessage, }; #[cfg(feature = "encryption")] @@ -374,93 +368,3 @@ impl SerfMessage { } } } - 
-/// Error that can occur when transforming a [`SerfMessage`] or [`SerfMessageRef`] -#[derive(thiserror::Error)] -pub enum SerfMessageTransformError -where - I: Transformable + core::hash::Hash + Eq, - A: Transformable + core::hash::Hash + Eq, -{ - /// [`LeaveMessage`] transformation error - #[error(transparent)] - Leave(#[from] LeaveMessageTransformError), - /// [`JoinMessage`] transformation error - #[error(transparent)] - Join(#[from] JoinMessageTransformError), - /// [`PushPullMessage`] transformation error - #[error(transparent)] - PushPull(#[from] PushPullMessageTransformError), - /// [`UserEventMessage`] transformation error - #[error(transparent)] - UserEvent(#[from] UserEventMessageTransformError), - /// [`QueryMessage`] transformation error - #[error(transparent)] - Query(#[from] QueryMessageTransformError), - /// [`QueryResponseMessage`] transformation error - #[error(transparent)] - QueryResponse(#[from] QueryResponseMessageTransformError), - /// [`Member`] transformation error - #[error(transparent)] - ConflictResponse(#[from] MemberTransformError), - /// [`KeyRequestMessage`] transformation error - #[cfg(feature = "encryption")] - #[error(transparent)] - KeyRequest(#[from] crate::key::OptionSecretKeyTransformError), - /// [`KeyResponseMessage`] transformation error - #[cfg(feature = "encryption")] - #[error(transparent)] - KeyResponse(#[from] crate::key::KeyResponseMessageTransformError), -} - -impl core::fmt::Debug for SerfMessageTransformError -where - I: Transformable + core::hash::Hash + Eq, - A: Transformable + core::hash::Hash + Eq, -{ - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "{}", self) - } -} - -impl Encodable for SerfMessageRef<'_, I, A> -where - I: Transformable + core::hash::Hash + Eq, - A: Transformable + core::hash::Hash + Eq, -{ - type Error = SerfMessageTransformError; - - #[inline] - fn encoded_len(&self) -> usize { - match *self { - Self::Leave(msg) => Transformable::encoded_len(msg), - 
Self::Join(msg) => Transformable::encoded_len(msg), - Self::PushPull(msg) => Encodable::encoded_len(&msg), - Self::UserEvent(msg) => Transformable::encoded_len(msg), - Self::Query(msg) => Transformable::encoded_len(msg), - Self::QueryResponse(msg) => Transformable::encoded_len(msg), - Self::ConflictResponse(msg) => Transformable::encoded_len(msg), - #[cfg(feature = "encryption")] - Self::KeyRequest(msg) => Transformable::encoded_len(msg), - #[cfg(feature = "encryption")] - Self::KeyResponse(msg) => Transformable::encoded_len(msg), - } - } - - #[inline] - fn encode(&self, dst: &mut [u8]) -> Result { - match *self { - Self::Leave(msg) => Transformable::encode(msg, dst).map_err(Into::into), - Self::Join(msg) => Transformable::encode(msg, dst).map_err(Into::into), - Self::PushPull(msg) => Encodable::encode(&msg, dst).map_err(Into::into), - Self::UserEvent(msg) => Transformable::encode(msg, dst).map_err(Into::into), - Self::Query(msg) => Transformable::encode(msg, dst).map_err(Into::into), - Self::QueryResponse(msg) => Transformable::encode(msg, dst).map_err(Into::into), - Self::ConflictResponse(msg) => Transformable::encode(msg, dst).map_err(Into::into), - #[cfg(feature = "encryption")] - Self::KeyRequest(msg) => Transformable::encode(msg, dst).map_err(Into::into), - #[cfg(feature = "encryption")] - Self::KeyResponse(msg) => Transformable::encode(msg, dst).map_err(Into::into), - } - } -} diff --git a/serf-proto/src/push_pull.rs b/serf-proto/src/push_pull.rs new file mode 100644 index 0000000..6d52f19 --- /dev/null +++ b/serf-proto/src/push_pull.rs @@ -0,0 +1,147 @@ +use indexmap::{IndexMap, IndexSet}; +use memberlist_proto::TinyVec; + + +use super::{LamportTime, UserEvents}; + +/// Used when doing a state exchange. 
This +/// is a relatively large message, but is sent infrequently +#[viewit::viewit(setters(prefix = "with"))] +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr( + feature = "serde", + serde(bound( + serialize = "I: core::cmp::Eq + core::hash::Hash + serde::Serialize", + deserialize = "I: core::cmp::Eq + core::hash::Hash + serde::Deserialize<'de>" + )) +)] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct PushPullMessage { + /// Current node lamport time + #[viewit( + getter(const, style = "move", attrs(doc = "Returns the lamport time")), + setter(const, attrs(doc = "Sets the lamport time (Builder pattern)")) + )] + ltime: LamportTime, + /// Maps the node to its status time + #[viewit( + getter( + const, + style = "ref", + attrs(doc = "Returns the maps the node to its status time") + ), + setter(attrs(doc = "Sets the maps the node to its status time (Builder pattern)")) + )] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::arbitrary_indexmap))] + status_ltimes: IndexMap, + /// List of left nodes + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the list of left nodes")), + setter(attrs(doc = "Sets the list of left nodes (Builder pattern)")) + )] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::arbitrary_indexset))] + left_members: IndexSet, + /// Lamport time for event clock + #[viewit( + getter( + const, + style = "move", + attrs(doc = "Returns the lamport time for event clock") + ), + setter( + const, + attrs(doc = "Sets the lamport time for event clock (Builder pattern)") + ) + )] + event_ltime: LamportTime, + /// Recent events + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the recent events")), + setter(attrs(doc = "Sets the recent events (Builder pattern)")) + )] + events: TinyVec>, + /// Lamport time for query clock + #[viewit( + getter( + const, + style = "move", + attrs(doc = 
"Returns the lamport time for query clock") + ), + setter( + const, + attrs(doc = "Sets the lamport time for query clock (Builder pattern)") + ) + )] + query_ltime: LamportTime, +} + +impl PartialEq for PushPullMessage +where + I: core::hash::Hash + Eq, +{ + fn eq(&self, other: &Self) -> bool { + self.ltime == other.ltime + && self.status_ltimes == other.status_ltimes + && self.left_members == other.left_members + && self.event_ltime == other.event_ltime + && self.events == other.events + && self.query_ltime == other.query_ltime + } +} + +/// Used when doing a state exchange. This +/// is a relatively large message, but is sent infrequently +#[viewit::viewit(getters(skip), setters(skip))] +#[derive(Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +pub struct PushPullMessageRef<'a, I> { + /// Current node lamport time + ltime: LamportTime, + /// Maps the node to its status time + status_ltimes: &'a IndexMap, + /// List of left nodes + left_members: &'a IndexSet, + /// Lamport time for event clock + event_ltime: LamportTime, + /// Recent events + events: &'a [Option], + /// Lamport time for query clock + query_ltime: LamportTime, +} + +impl Clone for PushPullMessageRef<'_, I> { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for PushPullMessageRef<'_, I> {} + +impl<'a, I> From<&'a PushPullMessage> for PushPullMessageRef<'a, I> { + #[inline] + fn from(msg: &'a PushPullMessage) -> Self { + Self { + ltime: msg.ltime, + status_ltimes: &msg.status_ltimes, + left_members: &msg.left_members, + event_ltime: msg.event_ltime, + events: &msg.events, + query_ltime: msg.query_ltime, + } + } +} + +impl<'a, I> From<&'a mut PushPullMessage> for PushPullMessageRef<'a, I> { + #[inline] + fn from(msg: &'a mut PushPullMessage) -> Self { + Self { + ltime: msg.ltime, + status_ltimes: &msg.status_ltimes, + left_members: &msg.left_members, + event_ltime: msg.event_ltime, + events: &msg.events, + query_ltime: msg.query_ltime, + } + } +} diff --git 
a/serf-proto/src/query.rs b/serf-proto/src/query.rs new file mode 100644 index 0000000..349f71e --- /dev/null +++ b/serf-proto/src/query.rs @@ -0,0 +1,162 @@ +use smol_str::SmolStr; + + +use std::time::Duration; + +use memberlist_proto::{Node, TinyVec, bytes::Bytes}; + +use super::LamportTime; + +bitflags::bitflags! { + /// Flags for query message + #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] + #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] + #[cfg_attr(feature = "serde", serde(transparent))] + pub struct QueryFlag: u32 { + /// Ack flag is used to force receiver to send an ack back + const ACK = 1 << 0; + /// NoBroadcast is used to prevent re-broadcast of a query. + /// this can be used to selectively send queries to individual members + const NO_BROADCAST = 1 << 1; + } +} + +/// Query message +#[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))] +#[derive(Debug, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct QueryMessage { + /// Event lamport time + #[viewit( + getter(const, style = "move", attrs(doc = "Returns the event lamport time")), + setter(const, attrs(doc = "Sets the event lamport time (Builder pattern)")) + )] + ltime: LamportTime, + /// query id, randomly generated + #[viewit( + getter(const, style = "move", attrs(doc = "Returns the query id")), + setter(attrs(doc = "Sets the query id (Builder pattern)")) + )] + id: u32, + /// source node + #[viewit( + getter(const, attrs(doc = "Returns the from node")), + setter(attrs(doc = "Sets the from node (Builder pattern)")) + )] + from: Node, + /// Potential query filters + #[viewit( + getter(const, attrs(doc = "Returns the potential query filters")), + setter(attrs(doc = "Sets the potential query filters (Builder pattern)")) + )] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, TinyVec>))] + 
filters: TinyVec, + /// Used to provide various flags + #[viewit( + getter(const, style = "move", attrs(doc = "Returns the flags")), + setter(attrs(doc = "Sets the flags (Builder pattern)")) + )] + flags: QueryFlag, + /// Used to set the number of duplicate relayed responses + #[viewit( + getter( + const, + style = "move", + attrs(doc = "Returns the number of duplicate relayed responses") + ), + setter(attrs(doc = "Sets the number of duplicate relayed responses (Builder pattern)")) + )] + relay_factor: u8, + /// Maximum time between delivery and response + #[viewit( + getter( + const, + style = "move", + attrs(doc = "Returns the maximum time between delivery and response") + ), + setter(attrs(doc = "Sets the maximum time between delivery and response (Builder pattern)")) + )] + timeout: Duration, + /// Query name + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the name of the query")), + setter(attrs(doc = "Sets the name of the query (Builder pattern)")) + )] + name: SmolStr, + /// Query payload + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the payload")), + setter(attrs(doc = "Sets the payload (Builder pattern)")) + )] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, Bytes>))] + payload: Bytes, +} + +impl QueryMessage { + /// Checks if the ack flag is set + #[inline] + pub fn ack(&self) -> bool { + self.flags.contains(QueryFlag::ACK) + } + + /// Checks if the no broadcast flag is set + #[inline] + pub fn no_broadcast(&self) -> bool { + self.flags.contains(QueryFlag::NO_BROADCAST) + } +} + +/// Query response message +#[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))] +#[derive(Debug, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct QueryResponseMessage { + /// Event lamport time + #[viewit( + getter(const, attrs(doc = "Returns the lamport time for this message")), + setter( + const, + attrs(doc = "Sets the lamport
time for this message (Builder pattern)") + ) + )] + ltime: LamportTime, + /// query id + #[viewit( + getter(const, attrs(doc = "Returns the query id")), + setter(attrs(doc = "Sets the query id (Builder pattern)")) + )] + id: u32, + /// node + #[viewit( + getter(const, attrs(doc = "Returns the from node")), + setter(attrs(doc = "Sets the from node (Builder pattern)")) + )] + from: Node, + /// Used to provide various flags + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the flags")), + setter(attrs(doc = "Sets the flags (Builder pattern)")) + )] + flags: QueryFlag, + /// Optional response payload + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the payload")), + setter(attrs(doc = "Sets the payload (Builder pattern)")) + )] + payload: Bytes, +} + +impl QueryResponseMessage { + /// Checks if the ack flag is set + #[inline] + pub fn ack(&self) -> bool { + self.flags.contains(QueryFlag::ACK) + } + + /// Checks if the no broadcast flag is set + #[inline] + pub fn no_broadcast(&self) -> bool { + self.flags.contains(QueryFlag::NO_BROADCAST) + } +} diff --git a/serf-proto/src/tags.rs b/serf-proto/src/tags.rs new file mode 100644 index 0000000..03822dd --- /dev/null +++ b/serf-proto/src/tags.rs @@ -0,0 +1,58 @@ +use indexmap::IndexMap; +use smol_str::SmolStr; + + +/// Tags of a node +#[derive( + Debug, + Default, + PartialEq, + Clone, + derive_more::From, + derive_more::Into, + derive_more::Deref, + derive_more::DerefMut, +)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(transparent))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct Tags(#[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::arbitrary_indexmap))] IndexMap); + +impl IntoIterator for Tags { + type Item = (SmolStr, SmolStr); + type IntoIter = indexmap::map::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl 
FromIterator<(SmolStr, SmolStr)> for Tags { + fn from_iter>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl<'a> FromIterator<(&'a str, &'a str)> for Tags { + fn from_iter>(iter: T) -> Self { + Self( + iter + .into_iter() + .map(|(k, v)| (SmolStr::new(k), SmolStr::new(v))) + .collect(), + ) + } +} + +impl Tags { + /// Create a new Tags + #[inline] + pub fn new() -> Self { + Self(IndexMap::new()) + } + + /// Create a new Tags with a capacity + pub fn with_capacity(cap: usize) -> Self { + Self(IndexMap::with_capacity(cap)) + } +} diff --git a/serf-proto/src/user_event.rs b/serf-proto/src/user_event.rs new file mode 100644 index 0000000..37eb2cd --- /dev/null +++ b/serf-proto/src/user_event.rs @@ -0,0 +1,109 @@ +use memberlist_proto::{bytes::Bytes, CheapClone, OneOrMore}; +use smol_str::SmolStr; + +use super::LamportTime; + +/// Stores all the user events at a specific time +#[viewit::viewit(setters(prefix = "with"))] +#[derive(Debug, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct UserEvents { + /// The lamport time + #[viewit( + getter(const, attrs(doc = "Returns the lamport time for this message")), + setter( + const, + attrs(doc = "Sets the lamport time for this message (Builder pattern)") + ) + )] + ltime: LamportTime, + + /// The user events + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the user events")), + setter(attrs(doc = "Sets the user events (Builder pattern)")) + )] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, OneOrMore>))] + events: OneOrMore, +} + +/// Used to buffer events to prevent re-delivery +#[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))] +#[derive(Debug, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub struct UserEvent { + /// The name of the event + #[viewit( + getter(const, attrs(doc = "Returns the name of the event")), + setter(attrs(doc = "Sets the name of the event (Builder pattern)")) + )] + name: SmolStr, + /// The payload of the event + #[viewit( + getter(const, attrs(doc = "Returns the payload of the event")), + setter(attrs(doc = "Sets the payload of the event (Builder pattern)")) + )] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, Bytes>))] + payload: Bytes, +} + +/// Used for user-generated events +#[viewit::viewit(setters(prefix = "with"))] +#[derive(Debug, Default, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct UserEventMessage { + /// The lamport time + #[viewit( + getter( + const, + style = "move", + attrs(doc = "Returns the lamport time for this message") + ), + setter( + const, + attrs(doc = "Sets the lamport time for this message (Builder pattern)") + ) + )] + ltime: LamportTime, + /// The name of the event + #[viewit( + getter(const, attrs(doc = "Returns the name of the event")), + setter(attrs(doc = "Sets the name of the event (Builder pattern)")) + )] + name: SmolStr, + /// The payload of the event + #[viewit( + getter(const, attrs(doc = "Returns the payload of the event")), + setter(attrs(doc = "Sets the payload of the event (Builder pattern)")) + )] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, Bytes>))] + payload: Bytes, + /// "Can Coalesce". 
+ #[viewit( + getter( + const, + style = "move", + attrs(doc = "Returns if this message can be coalesced") + ), + setter( + const, + attrs(doc = "Sets if this message can be coalesced (Builder pattern)") + ) + )] + cc: bool, +} + +impl CheapClone for UserEventMessage { + fn cheap_clone(&self) -> Self { + Self { + ltime: self.ltime, + name: self.name.cheap_clone(), + payload: self.payload.clone(), + cc: self.cc, + } + } +} + diff --git a/serf-proto/src/version.rs b/serf-proto/src/version.rs new file mode 100644 index 0000000..b8f6123 --- /dev/null +++ b/serf-proto/src/version.rs @@ -0,0 +1,149 @@ +use memberlist_proto::{Data, DataRef, DecodeError, EncodeError, WireType}; + +/// Delegate version +#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash, derive_more::IsVariant, derive_more::Display)] +#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +#[non_exhaustive] +pub enum DelegateVersion { + /// Version 1 + #[default] + #[display("v1")] + V1, + /// Unknown version (used for forwards and backwards compatibility) + #[display("unknown({_0})")] + Unknown(u8), +} + +impl From for DelegateVersion { + fn from(v: u8) -> Self { + match v { + 1 => Self::V1, + val => Self::Unknown(val), + } + } +} + +impl From for u8 { + fn from(v: DelegateVersion) -> Self { + match v { + DelegateVersion::V1 => 1, + DelegateVersion::Unknown(val) => val, + } + } +} + +/// Protocol version +#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash, derive_more::IsVariant, derive_more::Display)] +#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +#[non_exhaustive] +pub enum ProtocolVersion { + /// Version 1 + #[default] + #[display("v1")] + V1, + /// Unknown version (used for forwards and backwards compatibility) + #[display("unknown({_0})")] + Unknown(u8), +} + +impl From for ProtocolVersion { + fn 
from(v: u8) -> Self { + match v { + 1 => Self::V1, + val => Self::Unknown(val), + } + } +} + +impl From for u8 { + fn from(v: ProtocolVersion) -> Self { + match v { + ProtocolVersion::V1 => 1, + ProtocolVersion::Unknown(val) => val, + } + } +} + +macro_rules! impl_data { + ($($ty:ty),+$(,)?) => { + $( + impl<'a> DataRef<'a, Self> for $ty { + fn decode(src: &'a [u8]) -> Result<(usize, Self), DecodeError> { + if src.is_empty() { + return Err(DecodeError::buffer_underflow()); + } + + Ok((1, Self::from(src[0]))) + } + } + + impl Data for $ty { + const WIRE_TYPE: WireType = WireType::Byte; + + type Ref<'a> = Self; + + fn from_ref(val: Self::Ref<'_>) -> Result { + Ok(val) + } + + #[inline] + fn encoded_len(&self) -> usize { + 1 + } + + #[inline] + fn encode(&self, buf: &mut [u8]) -> Result { + if buf.is_empty() { + return Err(EncodeError::insufficient_buffer(1, 0)); + } + + buf[0] = u8::from(*self); + Ok(1) + } + } + )* + }; +} + +impl_data!(DelegateVersion, ProtocolVersion); + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "arbitrary")] + use arbitrary::{Arbitrary, Unstructured}; + + #[test] + #[cfg(feature = "arbitrary")] + fn test_delegate_version() { + let mut buf = [0; 64]; + rand::fill(&mut buf[..]); + + let mut data = Unstructured::new(&buf); + let _ = DelegateVersion::arbitrary(&mut data).unwrap(); + + assert_eq!(u8::from(DelegateVersion::V1), 1u8); + assert_eq!(DelegateVersion::V1.to_string(), "v1"); + assert_eq!(DelegateVersion::Unknown(2).to_string(), "unknown(2)"); + assert_eq!(DelegateVersion::from(1), DelegateVersion::V1); + assert_eq!(DelegateVersion::from(2), DelegateVersion::Unknown(2)); + } + + #[test] + #[cfg(feature = "arbitrary")] + fn test_protocol_version() { + let mut buf = [0; 64]; + rand::fill(&mut buf[..]); + + let mut data = Unstructured::new(&buf); + let _ = ProtocolVersion::arbitrary(&mut data).unwrap(); + assert_eq!(u8::from(ProtocolVersion::V1), 1); + assert_eq!(ProtocolVersion::V1.to_string(), "v1"); +
assert_eq!(ProtocolVersion::Unknown(2).to_string(), "unknown(2)"); + assert_eq!(ProtocolVersion::from(1), ProtocolVersion::V1); + assert_eq!(ProtocolVersion::from(2), ProtocolVersion::Unknown(2)); + } +} diff --git a/serf/Cargo.toml b/serf/Cargo.toml index 17acbb8..57e9ff7 100644 --- a/serf/Cargo.toml +++ b/serf/Cargo.toml @@ -32,18 +32,16 @@ metrics = [ "serf-core/metrics", ] -compression = ["memberlist/compression"] +compression = [] encryption = ["memberlist/encryption", "serf-core/encryption"] quic = ["memberlist/quic"] quinn = ["memberlist/quinn", "quic"] -s2n = ["memberlist/s2n", "quic"] net = ["memberlist/net"] tcp = ["net"] tls = ["memberlist/tls", "net"] -native-tls = ["memberlist/native-tls", "net"] # enable DNS node address resolver dns = ["memberlist/dns"] diff --git a/serf/test/main.rs b/serf/test/main.rs index ad527a9..c8ed71e 100644 --- a/serf/test/main.rs +++ b/serf/test/main.rs @@ -17,12 +17,12 @@ fn tokio_run(fut: impl Future) { #[cfg(feature = "smol")] fn smol_run(fut: impl Future) { - use serf::agnostic::{smol::SmolRuntime, RuntimeLite}; + use serf::agnostic::{RuntimeLite, smol::SmolRuntime}; run_unit_test(SmolRuntime::block_on, fut); } #[cfg(feature = "async-std")] fn async_std_run(fut: impl Future) { - use serf::agnostic::{async_std::AsyncStdRuntime, RuntimeLite}; + use serf::agnostic::{RuntimeLite, async_std::AsyncStdRuntime}; run_unit_test(AsyncStdRuntime::block_on, fut); } diff --git a/types/src/filter.rs b/types/src/filter.rs deleted file mode 100644 index 2075fce..0000000 --- a/types/src/filter.rs +++ /dev/null @@ -1,298 +0,0 @@ -use byteorder::{ByteOrder, NetworkEndian}; -use memberlist_types::TinyVec; -use smol_str::SmolStr; -use transformable::StringTransformError; - -use super::Transformable; - -/// Unknown filter type error -#[derive(Debug, thiserror::Error)] -#[error("unknown filter type: {0}")] -pub struct UnknownFilterType(u8); - -/// The type of filter -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[cfg_attr(feature =
"serde", derive(serde::Serialize, serde::Deserialize))] -#[repr(u8)] -#[non_exhaustive] -pub enum FilterType { - /// Filter by node ids - Id = 0, - /// Filter by tag - Tag = 1, -} - -impl FilterType { - /// Get the string representation of the filter type - #[inline] - pub const fn as_str(&self) -> &'static str { - match self { - Self::Id => "id", - Self::Tag => "tag", - } - } -} - -impl TryFrom for FilterType { - type Error = UnknownFilterType; - - fn try_from(value: u8) -> Result { - match value { - 0 => Ok(Self::Id), - 1 => Ok(Self::Tag), - other => Err(UnknownFilterType(other)), - } - } -} - -/// Transform error type for [`Filter`] -#[derive(thiserror::Error)] -pub enum FilterTransformError { - /// Returned when there are not enough bytes to decode - #[error("not enough bytes to decode")] - NotEnoughBytes(usize), - /// Returned when the buffer is too small to encode - #[error("encode buffer too small")] - BufferTooSmall, - /// Returned when there is an error decoding a node - #[error(transparent)] - Id(I::Error), - /// Returned when there is an error decoding a tag - #[error(transparent)] - Tag(#[from] StringTransformError), - /// Returned when there is an error decoding - #[error("not enough nodes, expected {expected} nodes, got {got} nodes")] - NotEnoughIds { - /// expected number of nodes - expected: usize, - /// got number of nodes - got: usize, - }, - /// Returned when there is an unknown filter type - #[error("unknown filter type: {0}")] - UnknownFilterType(#[from] UnknownFilterType), -} - -impl core::fmt::Debug for FilterTransformError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "{}", self) - } -} - -/// Used with a queryFilter to specify the type of -/// filter we are sending -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub enum Filter { - /// Filter by node ids - Id(TinyVec), - /// Filter by tag - Tag { - /// The tag to filter by - tag: 
SmolStr, - /// The expression to filter by - expr: SmolStr, - }, -} - -impl Filter { - /// Returns the type of filter - #[inline] - pub const fn ty(&self) -> FilterType { - match self { - Self::Id(_) => FilterType::Id, - Self::Tag { .. } => FilterType::Tag, - } - } -} - -impl Transformable for Filter -where - I: Transformable, -{ - type Error = FilterTransformError; - - fn encode(&self, dst: &mut [u8]) -> Result { - let encoded_len = self.encoded_len(); - if dst.len() < encoded_len { - return Err(Self::Error::BufferTooSmall); - } - - let ty = self.ty(); - let mut offset = 0; - NetworkEndian::write_u32(&mut dst[offset..offset + 4], encoded_len as u32); - offset += 4; - match self { - Self::Id(nodes) => { - dst[offset] = ty as u8; - offset += 1; - let len = nodes.len() as u32; - NetworkEndian::write_u32(&mut dst[offset..offset + 4], len); - offset += 4; - for node in nodes.iter() { - offset += node.encode(&mut dst[offset..]).map_err(Self::Error::Id)?; - } - Ok(offset) - } - Self::Tag { tag, expr } => { - dst[offset] = ty as u8; - offset += 1; - offset += tag.encode(&mut dst[offset..])?; - offset += expr.encode(&mut dst[offset..])?; - Ok(offset) - } - } - } - - fn encoded_len(&self) -> usize { - 4 + match self { - Self::Id(nodes) => 1 + 4 + nodes.iter().map(Transformable::encoded_len).sum::(), - Self::Tag { tag, expr } => 1 + tag.encoded_len() + expr.encoded_len(), - } - } - - fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error> - where - Self: Sized, - { - let src_len = src.len(); - if src_len < 5 { - return Err(Self::Error::NotEnoughBytes(5)); - } - - let len = NetworkEndian::read_u32(&src[0..4]) as usize; - if src_len < len { - return Err(Self::Error::NotEnoughBytes(len)); - } - - let ty = FilterType::try_from(src[4])?; - let mut offset = 5; - match ty { - FilterType::Id => { - let total_nodes = NetworkEndian::read_u32(&src[offset..offset + 4]) as usize; - offset += 4; - let mut nodes = TinyVec::with_capacity(total_nodes); - for _ in 0..total_nodes { - let 
(n, node) = I::decode(&src[offset..]).map_err(Self::Error::Id)?; - nodes.push(node); - offset += n; - } - - debug_assert_eq!( - len, offset, - "expected read {} bytes, but actual read {} bytes", - len, offset - ); - - Ok((offset, Self::Id(nodes))) - } - FilterType::Tag => { - let (n, tag) = SmolStr::decode(&src[offset..])?; - offset += n; - let (n, expr) = SmolStr::decode(&src[offset..])?; - offset += n; - - debug_assert_eq!( - len, offset, - "expected read {} bytes, but actual read {} bytes", - len, offset - ); - - Ok((offset, Self::Tag { tag, expr })) - } - } - } -} - -#[cfg(test)] -mod tests { - use rand::{distributions::Alphanumeric, thread_rng, Rng}; - - use super::*; - - impl Filter { - fn random_node(size: usize, num_nodes: usize) -> Self { - let mut nodes = TinyVec::with_capacity(num_nodes); - - for _ in 0..num_nodes { - let id = thread_rng() - .sample_iter(Alphanumeric) - .take(size) - .collect::>(); - let id = String::from_utf8(id).unwrap().into(); - nodes.push(id); - } - Self::Id(nodes) - } - - fn random_tag(size: usize) -> Self { - let rng = rand::thread_rng(); - let tag = rng - .sample_iter(&Alphanumeric) - .take(size) - .collect::>(); - let tag = String::from_utf8(tag).unwrap(); - let rng = rand::thread_rng(); - let expr = rng - .sample_iter(&Alphanumeric) - .take(size) - .collect::>(); - let expr = String::from_utf8(expr).unwrap(); - Self::Tag { - tag: tag.into(), - expr: expr.into(), - } - } - } - - #[test] - fn test_transfrom_encode_decode() { - futures::executor::block_on(async { - for i in 0..100 { - let filter = Filter::random_tag(i); - let mut buf = vec![0; filter.encoded_len()]; - let encoded_len = filter.encode(&mut buf).unwrap(); - assert_eq!(encoded_len, filter.encoded_len()); - - let (decoded_len, decoded) = Filter::::decode(&buf).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - - let (decoded_len, decoded) = - Filter::::decode_from_reader(&mut std::io::Cursor::new(&buf)).unwrap(); - 
assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - - let (decoded_len, decoded) = - Filter::::decode_from_async_reader(&mut futures::io::Cursor::new(&buf)) - .await - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - } - - for i in 0..100 { - let filter = Filter::random_node(i, i % 10); - let mut buf = vec![0; filter.encoded_len()]; - let encoded_len = filter.encode(&mut buf).unwrap(); - assert_eq!(encoded_len, filter.encoded_len()); - - let (decoded_len, decoded) = Filter::::decode(&buf).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - - let (decoded_len, decoded) = - Filter::::decode_from_reader(&mut std::io::Cursor::new(&buf)).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - - let (decoded_len, decoded) = - Filter::::decode_from_async_reader(&mut futures::io::Cursor::new(&buf)) - .await - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - } - }); - } -} diff --git a/types/src/join.rs b/types/src/join.rs deleted file mode 100644 index d32de1e..0000000 --- a/types/src/join.rs +++ /dev/null @@ -1,188 +0,0 @@ -use byteorder::{ByteOrder, NetworkEndian}; -use transformable::utils::encoded_u64_varint_len; - -use crate::LamportTimeTransformError; - -use super::{LamportTime, Transformable}; - -/// The message broadcasted after we join to -/// associated the node with a lamport clock -#[viewit::viewit(setters(prefix = "with"))] -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct JoinMessage { - /// The lamport time - #[viewit( - getter(const, attrs(doc = "Returns the lamport time for this message")), - setter( - const, - attrs(doc = "Sets the lamport time for this message (Builder pattern)") - ) - )] - ltime: LamportTime, - /// The id of the node - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns the node")), - setter(attrs(doc = 
"Sets the node (Builder pattern)")) - )] - id: I, -} - -impl JoinMessage { - /// Create a new join message - pub fn new(ltime: LamportTime, id: I) -> Self { - Self { ltime, id } - } - - /// Set the lamport time - #[inline] - pub fn set_ltime(&mut self, ltime: LamportTime) -> &mut Self { - self.ltime = ltime; - self - } - - /// Set the id of the node - #[inline] - pub fn set_id(&mut self, id: I) -> &mut Self { - self.id = id; - self - } -} - -/// Error that can occur when transforming a JoinMessage -#[derive(thiserror::Error)] -pub enum JoinMessageTransformError { - /// Not enough bytes to decode JoinMessage - #[error("not enough bytes to decode JoinMessage")] - NotEnoughBytes, - /// Encode buffer too small - #[error("encode buffer too small")] - EncodeBufferTooSmall, - /// Error transforming Id - #[error(transparent)] - Id(I::Error), - - /// Error transforming LamportTime - #[error(transparent)] - LamportTime(#[from] LamportTimeTransformError), -} - -impl core::fmt::Debug for JoinMessageTransformError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "{}", self) - } -} - -impl Transformable for JoinMessage -where - I: Transformable, -{ - type Error = JoinMessageTransformError; - - fn encode(&self, dst: &mut [u8]) -> Result { - let encoded_len = self.encoded_len(); - if dst.len() < encoded_len { - return Err(Self::Error::EncodeBufferTooSmall); - } - - let mut offset = 0; - NetworkEndian::write_u32(&mut dst[offset..offset + 4], encoded_len as u32); - offset += 4; - - offset += self.ltime.encode(&mut dst[offset..])?; - offset += self - .id - .encode(&mut dst[offset..]) - .map_err(Self::Error::Id)?; - - debug_assert_eq!( - offset, encoded_len, - "expect write {} bytes, but actual write {} bytes", - encoded_len, offset - ); - Ok(offset) - } - - fn encoded_len(&self) -> usize { - 4 + encoded_u64_varint_len(self.ltime.0) + self.id.encoded_len() - } - - fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error> - where - Self: Sized, - 
{ - if src.len() < 4 { - return Err(Self::Error::NotEnoughBytes); - } - - let encoded_len = NetworkEndian::read_u32(&src[..4]) as usize; - if src.len() < encoded_len { - return Err(Self::Error::NotEnoughBytes); - } - - let mut offset = 4; - let (n, ltime) = LamportTime::decode(&src[offset..])?; - offset += n; - - let (n, id) = I::decode(&src[offset..]).map_err(Self::Error::Id)?; - offset += n; - - debug_assert_eq!( - offset, encoded_len, - "expect read {} bytes, but actual read {} bytes", - encoded_len, offset - ); - Ok((encoded_len, Self { ltime, id })) - } -} - -#[cfg(test)] -mod tests { - use rand::{distributions::Alphanumeric, thread_rng, Rng}; - use smol_str::SmolStr; - - use super::*; - - impl JoinMessage { - fn random(size: usize) -> Self { - let id = thread_rng() - .sample_iter(Alphanumeric) - .take(size) - .collect::>(); - let id = String::from_utf8(id).unwrap().into(); - - Self { - ltime: LamportTime::random(), - id, - } - } - } - - #[test] - fn test_transfrom_encode_decode() { - futures::executor::block_on(async { - for i in 0..100 { - let filter = JoinMessage::random(i); - let mut buf = vec![0; filter.encoded_len()]; - let encoded_len = filter.encode(&mut buf).unwrap(); - assert_eq!(encoded_len, filter.encoded_len()); - - let (decoded_len, decoded) = JoinMessage::::decode(&buf).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - - let (decoded_len, decoded) = - JoinMessage::::decode_from_reader(&mut std::io::Cursor::new(&buf)).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - - let (decoded_len, decoded) = - JoinMessage::::decode_from_async_reader(&mut futures::io::Cursor::new(&buf)) - .await - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - } - }); - } -} diff --git a/types/src/key.rs b/types/src/key.rs deleted file mode 100644 index 79aa2c9..0000000 --- a/types/src/key.rs +++ /dev/null @@ -1,617 +0,0 @@ -use byteorder::{ByteOrder, NetworkEndian}; -use 
indexmap::IndexMap; -use memberlist_types::{SecretKey, SecretKeyTransformError, SecretKeys, SecretKeysTransformError}; -use smol_str::SmolStr; -use transformable::{StringTransformError, Transformable}; - -/// KeyRequest is used to contain input parameters which get broadcasted to all -/// nodes as part of a key query operation. -#[viewit::viewit(setters(prefix = "with"))] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(transparent)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[cfg_attr(feature = "serde", serde(transparent))] -pub struct KeyRequestMessage { - /// The secret key - #[viewit( - getter(const, attrs(doc = "Returns the secret key")), - setter(const, attrs(doc = "Sets the secret key (Builder pattern)")) - )] - key: Option, -} - -/// The error that can occur when transforming a [`KeyRequestMessage`] -#[derive(Debug, thiserror::Error)] -pub enum OptionSecretKeyTransformError { - /// Not enough bytes to decode [`Option`] - #[error("not enough bytes to decode")] - NotEnoughBytes, - /// Encode buffer too small - #[error("encode buffer too small")] - BufferTooSmall, - /// Error transforming a secret key - #[error(transparent)] - SecretKey(#[from] SecretKeyTransformError), -} - -impl Transformable for KeyRequestMessage { - type Error = OptionSecretKeyTransformError; - - fn encode(&self, dst: &mut [u8]) -> Result { - let encoded_len = self.encoded_len(); - if dst.len() < encoded_len { - return Err(Self::Error::BufferTooSmall); - } - - match &self.key { - None => { - dst[0] = 0; - Ok(1) - } - Some(key) => key.encode(dst).map_err(Self::Error::SecretKey), - } - } - - fn encoded_len(&self) -> usize { - match &self.key { - Some(key) => key.encoded_len(), - None => 1, - } - } - - fn encode_to_writer(&self, writer: &mut W) -> std::io::Result { - match &self.key { - None => { - writer.write_all(&[0])?; - Ok(1) - } - Some(key) => key.encode_to_writer(writer), - } - } - - async fn encode_to_async_writer( - &self, - writer: &mut W, - 
) -> std::io::Result { - use futures::AsyncWriteExt; - - match &self.key { - None => { - writer.write_all(&[0]).await?; - Ok(1) - } - Some(key) => key.encode_to_async_writer(writer).await, - } - } - - fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error> - where - Self: Sized, - { - if src.is_empty() { - return Err(Self::Error::NotEnoughBytes); - } - - match src[0] { - 0 => Ok((1, Self { key: None })), - _ => { - let (n, key) = SecretKey::decode(src).map_err(Self::Error::SecretKey)?; - Ok((n, Self { key: Some(key) })) - } - } - } - - fn decode_from_reader(reader: &mut R) -> std::io::Result<(usize, Self)> - where - Self: Sized, - { - let mut buf = [0u8; 1]; - reader.read_exact(&mut buf)?; - - match buf[0] { - 0 => Ok((1, Self { key: None })), - 16 => { - let mut buf = [0u8; 16]; - reader.read_exact(&mut buf)?; - Ok(( - 17, - Self { - key: Some(SecretKey::from(buf)), - }, - )) - } - 24 => { - let mut buf = [0u8; 24]; - reader.read_exact(&mut buf)?; - Ok(( - 25, - Self { - key: Some(SecretKey::from(buf)), - }, - )) - } - 32 => { - let mut buf = [0u8; 32]; - reader.read_exact(&mut buf)?; - Ok(( - 33, - Self { - key: Some(SecretKey::from(buf)), - }, - )) - } - _ => Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "unknown secret key kind", - )), - } - } - - async fn decode_from_async_reader( - reader: &mut R, - ) -> std::io::Result<(usize, Self)> - where - Self: Sized, - { - use futures::AsyncReadExt; - - let mut buf = [0u8; 1]; - reader.read_exact(&mut buf).await?; - - match buf[0] { - 0 => Ok((1, Self { key: None })), - 16 => { - let mut buf = [0u8; 16]; - reader.read_exact(&mut buf).await?; - Ok(( - 17, - Self { - key: Some(SecretKey::from(buf)), - }, - )) - } - 24 => { - let mut buf = [0u8; 24]; - reader.read_exact(&mut buf).await?; - Ok(( - 25, - Self { - key: Some(SecretKey::from(buf)), - }, - )) - } - 32 => { - let mut buf = [0u8; 32]; - reader.read_exact(&mut buf).await?; - Ok(( - 33, - Self { - key: Some(SecretKey::from(buf)), - }, - )) - } - 
_ => Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "unknown secret key kind", - )), - } - } -} - -/// Key response message -#[viewit::viewit(setters(prefix = "with"))] -#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] -#[cfg(feature = "encryption")] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct KeyResponseMessage { - /// Indicates true/false if there were errors or not - #[viewit( - getter(const, attrs(doc = "Returns true/false if there were errors or not")), - setter( - const, - attrs(doc = "Sets true/false if there were errors or not (Builder pattern)") - ) - )] - result: bool, - /// Contains error messages or other information - #[viewit( - getter( - const, - style = "ref", - attrs(doc = "Returns the error messages or other information") - ), - setter(attrs(doc = "Sets the error messages or other information (Builder pattern)")) - )] - message: SmolStr, - /// Used in listing queries to relay a list of installed keys - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns a list of installed keys")), - setter(attrs(doc = "Sets the the list of installed keys (Builder pattern)")) - )] - keys: SecretKeys, - /// Used in listing queries to relay the primary key - #[viewit( - getter(const, attrs(doc = "Returns the primary key")), - setter(attrs(doc = "Sets the primary key (Builder pattern)")) - )] - primary_key: Option, -} - -impl KeyResponseMessage { - /// Adds a key to the list of keys - #[inline] - pub fn add_key(&mut self, key: SecretKey) -> &mut Self { - self.keys.push(key); - self - } -} - -/// Error that can occur when transforming a [`KeyResponseMessage`]. 
-#[derive(Debug, thiserror::Error)] -pub enum KeyResponseMessageTransformError { - /// Not enough bytes to decode KeyResponseMessage - #[error("not enough bytes to decode `KeyResponseMessage`")] - NotEnoughBytes, - /// Encode buffer too small - #[error("encode buffer too small")] - BufferTooSmall, - /// Error transforming a message field - #[error(transparent)] - Message(#[from] StringTransformError), - /// Error transforming a `primary_key` field - #[error(transparent)] - PrimaryKey(#[from] OptionSecretKeyTransformError), - /// Error transforming a `keys` field - #[error(transparent)] - Keys(#[from] SecretKeysTransformError), -} - -impl Transformable for KeyResponseMessage { - type Error = KeyResponseMessageTransformError; - - fn encode(&self, dst: &mut [u8]) -> Result { - let encoded_len = self.encoded_len(); - if dst.len() < encoded_len { - return Err(Self::Error::BufferTooSmall); - } - - let mut offset = 0; - NetworkEndian::write_u32(&mut dst[offset..offset + 4], encoded_len as u32); - offset += 4; - dst[offset] = self.result as u8; - offset += 1; - offset += self.message.encode(&mut dst[offset..])?; - offset += self.keys.encode(&mut dst[offset..])?; - offset += KeyRequestMessage { - key: self.primary_key, - } - .encode(&mut dst[offset..])?; - - debug_assert_eq!( - offset, encoded_len, - "expect write {} bytes, but actual write {} bytes", - encoded_len, offset - ); - - Ok(offset) - } - - fn encoded_len(&self) -> usize { - 4 + 1 - + self.message.encoded_len() - + self.keys.encoded_len() - + KeyRequestMessage { - key: self.primary_key, - } - .encoded_len() - } - - fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error> - where - Self: Sized, - { - let src_len = src.len(); - if src_len < 5 { - return Err(Self::Error::NotEnoughBytes); - } - - let mut offset = 0; - let encoded_len = NetworkEndian::read_u32(&src[offset..offset + 4]) as usize; - if src_len < encoded_len { - return Err(Self::Error::NotEnoughBytes); - } - offset += 4; - - let result = src[offset] != 
0; - offset += 1; - let (n, message) = SmolStr::decode(&src[offset..])?; - offset += n; - let (n, keys) = SecretKeys::decode(&src[offset..])?; - offset += n; - let (n, primary_key) = KeyRequestMessage::decode(&src[offset..])?; - offset += n; - - debug_assert_eq!( - offset, encoded_len, - "expect read {} bytes, but actual read {} bytes", - encoded_len, offset - ); - - Ok(( - offset, - Self { - result, - message, - keys, - primary_key: primary_key.key, - }, - )) - } -} - -/// KeyResponse is used to relay a query for a list of all keys in use. -#[viewit::viewit(setters(prefix = "with"))] -#[derive(Default)] -pub struct KeyResponse { - /// Map of node id to response message - #[viewit( - getter( - const, - style = "ref", - attrs(doc = "Returns the map of node id to response message") - ), - setter(attrs(doc = "Sets the map of node id to response message (Builder pattern)")) - )] - messages: IndexMap, - /// Total nodes memberlist knows of - #[viewit( - getter(const, attrs(doc = "Returns the total nodes memberlist knows of")), - setter( - const, - attrs(doc = "Sets total nodes memberlist knows of (Builder pattern)") - ) - )] - num_nodes: usize, - /// Total responses received - #[viewit( - getter(const, attrs(doc = "Returns the total responses received")), - setter( - const, - attrs(doc = "Sets the total responses received (Builder pattern)") - ) - )] - num_resp: usize, - /// Total errors from request - #[viewit( - getter(const, attrs(doc = "Returns the total errors from request")), - setter( - const, - attrs(doc = "Sets the total errors from request (Builder pattern)") - ) - )] - num_err: usize, - - /// A mapping of the value of the key bytes to the - /// number of nodes that have the key installed. - #[viewit( - getter( - const, - style = "ref", - attrs( - doc = "Returns a mapping of the value of the key bytes to the number of nodes that have the key installed." 
- ) - ), - setter(attrs( - doc = "Sets a mapping of the value of the key bytes to the number of nodes that have the key installed (Builder pattern)" - )) - )] - keys: IndexMap, - - /// A mapping of the value of the primary - /// key bytes to the number of nodes that have the key installed. - #[viewit( - getter( - const, - style = "ref", - attrs( - doc = "Returns a mapping of the value of the primary key bytes to the number of nodes that have the key installed." - ) - ), - setter(attrs( - doc = "Sets a mapping of the value of the primary key bytes to the number of nodes that have the key installed. (Builder pattern)" - )) - )] - primary_keys: IndexMap, -} - -/// KeyRequestOptions is used to contain optional parameters for a keyring operation -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct KeyRequestOptions { - /// The number of duplicate query responses to send by relaying through - /// other nodes, for redundancy - pub relay_factor: u8, -} - -#[cfg(test)] -mod tests { - use rand::{distributions::Alphanumeric, thread_rng, Rng}; - - use super::*; - - impl KeyRequestMessage { - pub(crate) fn random(kind: u8) -> Self { - let key = if rand::random() { - match kind { - 16 => { - let mut buf = [0u8; 16]; - rand::thread_rng().fill(&mut buf); - Some(SecretKey::from(buf)) - } - 24 => { - let mut buf = [0u8; 24]; - rand::thread_rng().fill(&mut buf); - Some(SecretKey::from(buf)) - } - 32 => { - let mut buf = [0u8; 32]; - rand::thread_rng().fill(&mut buf); - Some(SecretKey::from(buf)) - } - _ => None, - } - } else { - None - }; - - Self { key } - } - } - - impl KeyResponseMessage { - pub(crate) fn random(num_keys: usize, size: usize) -> Self { - let mut keys = SecretKeys::new(); - for i in 0..num_keys { - let kind = match i % 3 { - 0 => 16, - 1 => 24, - 2 => 32, - _ => unreachable!(), - }; - let key = match kind { - 16 => { - let mut buf = [0u8; 16]; - rand::thread_rng().fill(&mut buf); - SecretKey::from(buf) - } - 24 => { - let mut buf = [0u8; 24]; - 
rand::thread_rng().fill(&mut buf); - SecretKey::from(buf) - } - 32 => { - let mut buf = [0u8; 32]; - rand::thread_rng().fill(&mut buf); - SecretKey::from(buf) - } - _ => unreachable!(), - }; - keys.push(key); - } - - let primary_key = if rand::random() { - let mut buf = [0u8; 32]; - rand::thread_rng().fill(&mut buf); - Some(SecretKey::from(buf)) - } else { - None - }; - - let message = thread_rng() - .sample_iter(Alphanumeric) - .take(size) - .collect::>(); - let message = String::from_utf8(message).unwrap().into(); - - Self { - result: rand::random(), - message, - keys, - primary_key, - } - } - } - - #[test] - fn test_key_request_message_transform() { - futures::executor::block_on(async { - for i in 0..100 { - let kind = match i % 4 { - 0 => 0, - 1 => 16, - 2 => 24, - _ => 32, - }; - let key = KeyRequestMessage::random(kind); - let mut buf = vec![0; key.encoded_len()]; - let encoded_len = key.encode(&mut buf).unwrap(); - assert_eq!(encoded_len, key.encoded_len()); - let mut buf1 = vec![]; - let encoded_len1 = key.encode_to_writer(&mut buf1).unwrap(); - assert_eq!(encoded_len1, key.encoded_len()); - let mut buf2 = vec![]; - let encoded_len2 = key.encode_to_async_writer(&mut buf2).await.unwrap(); - assert_eq!(encoded_len2, key.encoded_len()); - - let (decoded_len, decoded) = KeyRequestMessage::decode(&buf).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, key); - let (decoded_len, decoded) = KeyRequestMessage::decode(&buf1).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, key); - let (decoded_len, decoded) = KeyRequestMessage::decode(&buf2).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, key); - - let (decoded_len, decoded) = - KeyRequestMessage::decode_from_reader(&mut std::io::Cursor::new(&buf)).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, key); - let (decoded_len, decoded) = - KeyRequestMessage::decode_from_reader(&mut std::io::Cursor::new(&buf1)).unwrap(); - 
assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, key); - let (decoded_len, decoded) = - KeyRequestMessage::decode_from_reader(&mut std::io::Cursor::new(&buf2)).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, key); - - let (decoded_len, decoded) = - KeyRequestMessage::decode_from_async_reader(&mut futures::io::Cursor::new(&buf)) - .await - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, key); - let (decoded_len, decoded) = - KeyRequestMessage::decode_from_async_reader(&mut futures::io::Cursor::new(&buf1)) - .await - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, key); - let (decoded_len, decoded) = - KeyRequestMessage::decode_from_async_reader(&mut futures::io::Cursor::new(&buf2)) - .await - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, key); - } - }); - } - - #[test] - fn test_key_response_message_transform() { - futures::executor::block_on(async { - for i in 0..100 { - let message = KeyResponseMessage::random(i % 10, i); - let mut buf = vec![0; message.encoded_len()]; - let encoded_len = message.encode(&mut buf).unwrap(); - assert_eq!(encoded_len, message.encoded_len()); - - let (decoded_len, decoded) = KeyResponseMessage::decode(&buf).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, message); - - let (decoded_len, decoded) = - KeyResponseMessage::decode_from_reader(&mut std::io::Cursor::new(&buf)).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, message); - - let (decoded_len, decoded) = - KeyResponseMessage::decode_from_async_reader(&mut futures::io::Cursor::new(&buf)) - .await - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, message); - } - }); - } -} diff --git a/types/src/leave.rs b/types/src/leave.rs deleted file mode 100644 index 0a9075b..0000000 --- a/types/src/leave.rs +++ /dev/null @@ -1,178 +0,0 @@ -use byteorder::{ByteOrder, NetworkEndian}; - -use super::{LamportTime, 
LamportTimeTransformError, Transformable}; - -/// The message broadcasted to signal the intentional to -/// leave. -#[viewit::viewit(setters(prefix = "with"))] -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct LeaveMessage { - /// The lamport time - #[viewit( - getter(const, attrs(doc = "Returns the lamport time for this message")), - setter( - const, - attrs(doc = "Sets the lamport time for this message (Builder pattern)") - ) - )] - ltime: LamportTime, - /// The id of the node - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns the node")), - setter(attrs(doc = "Sets the node (Builder pattern)")) - )] - id: I, - - /// If prune or not - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns if prune or not")), - setter(attrs(doc = "Sets prune or not (Builder pattern)")) - )] - prune: bool, -} - -/// Error that can occur when transforming a [`LeaveMessage`]. -#[derive(thiserror::Error)] -pub enum LeaveMessageTransformError { - /// Not enough bytes to decode LeaveMessage - #[error("not enough bytes to decode LeaveMessage")] - NotEnoughBytes, - /// Encode buffer too small - #[error("encode buffer too small")] - EncodeBufferTooSmall, - /// Error transforming Node - #[error(transparent)] - Id(I::Error), - /// Error transforming LamportTime - #[error(transparent)] - LamportTime(#[from] LamportTimeTransformError), -} - -impl core::fmt::Debug for LeaveMessageTransformError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "{}", self) - } -} - -impl Transformable for LeaveMessage -where - I: Transformable, -{ - type Error = LeaveMessageTransformError; - - fn encode(&self, dst: &mut [u8]) -> Result { - let encoded_len = self.encoded_len(); - if dst.len() < encoded_len { - return Err(Self::Error::EncodeBufferTooSmall); - } - - let mut offset = 0; - NetworkEndian::write_u32(&mut dst[offset..], encoded_len as u32); - offset += 4; - dst[offset] 
= self.prune as u8; - offset += 1; - offset += self.ltime.encode(&mut dst[offset..])?; - offset += self - .id - .encode(&mut dst[offset..]) - .map_err(Self::Error::Id)?; - - debug_assert_eq!( - offset, encoded_len, - "expect write {} bytes, but actual write {} bytes", - encoded_len, offset - ); - - Ok(offset) - } - - fn encoded_len(&self) -> usize { - 4 + 1 + self.id.encoded_len() + self.ltime.encoded_len() - } - - fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error> - where - Self: Sized, - { - if src.len() < 5 { - return Err(Self::Error::NotEnoughBytes); - } - - let len = NetworkEndian::read_u32(&src[0..4]) as usize; - if src.len() + 5 < len { - return Err(Self::Error::NotEnoughBytes); - } - - let mut offset = 4; - let prune = src[offset] != 0; - offset += 1; - - let (read, ltime) = LamportTime::decode(&src[offset..])?; - offset += read; - - let (read, id) = I::decode(&src[offset..]).map_err(Self::Error::Id)?; - offset += read; - - debug_assert_eq!( - offset, len, - "expect read {} bytes, but actual read {} bytes", - len, offset - ); - - Ok((offset, Self { ltime, id, prune })) - } -} - -#[cfg(test)] -mod tests { - - use rand::{distributions::Alphanumeric, thread_rng, Rng}; - use smol_str::SmolStr; - - use super::*; - - impl LeaveMessage { - fn random(size: usize) -> Self { - let id = thread_rng() - .sample_iter(Alphanumeric) - .take(size) - .collect::>(); - let id = String::from_utf8(id).unwrap().into(); - - Self { - ltime: LamportTime::random(), - id, - prune: thread_rng().gen(), - } - } - } - - #[test] - fn test_leave_message_transform() { - futures::executor::block_on(async { - for i in 0..100 { - let filter = LeaveMessage::random(i); - let mut buf = vec![0; filter.encoded_len()]; - let encoded_len = filter.encode(&mut buf).unwrap(); - assert_eq!(encoded_len, filter.encoded_len()); - - let (decoded_len, decoded) = LeaveMessage::::decode(&buf).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - - let (decoded_len, decoded) 
= - LeaveMessage::::decode_from_reader(&mut std::io::Cursor::new(&buf)).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - - let (decoded_len, decoded) = - LeaveMessage::::decode_from_async_reader(&mut futures::io::Cursor::new(&buf)) - .await - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - } - }); - } -} diff --git a/types/src/member.rs b/types/src/member.rs deleted file mode 100644 index effe349..0000000 --- a/types/src/member.rs +++ /dev/null @@ -1,402 +0,0 @@ -use std::sync::Arc; - -use byteorder::{ByteOrder, NetworkEndian}; -use memberlist_types::CheapClone; - -use super::{ - DelegateVersion, MemberlistDelegateVersion, MemberlistProtocolVersion, Node, NodeTransformError, - ProtocolVersion, Tags, TagsTransformError, Transformable, UnknownDelegateVersion, - UnknownMemberlistDelegateVersion, UnknownMemberlistProtocolVersion, UnknownProtocolVersion, -}; - -/// The member status. -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, bytemuck::NoUninit)] -#[repr(u8)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub enum MemberStatus { - /// None status - None = 0, - /// Alive status - Alive = 1, - /// Leaving status - Leaving = 2, - /// Left status - Left = 3, - /// Failed status - Failed = 4, -} - -impl core::fmt::Display for MemberStatus { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.as_str()) - } -} - -impl TryFrom for MemberStatus { - type Error = UnknownMemberStatus; - - fn try_from(value: u8) -> Result { - match value { - 0 => Ok(Self::None), - 1 => Ok(Self::Alive), - 2 => Ok(Self::Leaving), - 3 => Ok(Self::Left), - 4 => Ok(Self::Failed), - _ => Err(UnknownMemberStatus(value)), - } - } -} - -impl MemberStatus { - /// Get the string representation of the member status - #[inline] - pub const fn as_str(&self) -> &'static str { - match self { - Self::None => "none", - Self::Alive => "alive", - Self::Leaving => 
"leaving", - Self::Left => "left", - Self::Failed => "failed", - } - } -} - -/// Unknown member status -#[derive(Debug, Copy, Clone, thiserror::Error)] -#[error("Unknown member status: {0}")] -pub struct UnknownMemberStatus(u8); - -/// A single member of the Serf cluster. -#[viewit::viewit(setters(prefix = "with"))] -#[derive(Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Member { - /// The node - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns the node")), - setter(attrs(doc = "Sets the node (Builder pattern)")) - )] - node: Node, - /// The tags - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns the tags")), - setter(attrs(doc = "Sets the tags (Builder pattern)")) - )] - tags: Arc, - /// The status - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns the status")), - setter(attrs(doc = "Sets the status (Builder pattern)")) - )] - status: MemberStatus, - /// The memberlist protocol version - #[viewit( - getter(const, attrs(doc = "Returns the memberlist protocol version")), - setter( - const, - attrs(doc = "Sets the memberlist protocol version (Builder pattern)") - ) - )] - memberlist_protocol_version: MemberlistProtocolVersion, - /// The memberlist delegate version - #[viewit( - getter(const, attrs(doc = "Returns the memberlist delegate version")), - setter( - const, - attrs(doc = "Sets the memberlist delegate version (Builder pattern)") - ) - )] - memberlist_delegate_version: MemberlistDelegateVersion, - - /// The serf protocol version - #[viewit( - getter(const, attrs(doc = "Returns the serf protocol version")), - setter( - const, - attrs(doc = "Sets the serf protocol version (Builder pattern)") - ) - )] - protocol_version: ProtocolVersion, - /// The serf delegate version - #[viewit( - getter(const, attrs(doc = "Returns the serf delegate version")), - setter( - const, - attrs(doc = "Sets the serf delegate version (Builder pattern)") - ) - )] - delegate_version: 
DelegateVersion, -} - -impl Member { - /// Create a new member with the given node, tags, and status. - /// Other fields are set to their default values. - #[inline] - pub fn new(node: Node, tags: Tags, status: MemberStatus) -> Self { - Self { - node, - tags: Arc::new(tags), - status, - memberlist_protocol_version: MemberlistProtocolVersion::V1, - memberlist_delegate_version: MemberlistDelegateVersion::V1, - protocol_version: ProtocolVersion::V1, - delegate_version: DelegateVersion::V1, - } - } -} - -impl Clone for Member { - fn clone(&self) -> Self { - Self { - node: self.node.clone(), - tags: self.tags.clone(), - status: self.status, - memberlist_protocol_version: self.memberlist_protocol_version, - memberlist_delegate_version: self.memberlist_delegate_version, - protocol_version: self.protocol_version, - delegate_version: self.delegate_version, - } - } -} - -impl CheapClone for Member { - fn cheap_clone(&self) -> Self { - Self { - node: self.node.cheap_clone(), - tags: self.tags.cheap_clone(), - status: self.status, - memberlist_protocol_version: self.memberlist_protocol_version, - memberlist_delegate_version: self.memberlist_delegate_version, - protocol_version: self.protocol_version, - delegate_version: self.delegate_version, - } - } -} - -/// Error transforming the [`Member`] -#[derive(thiserror::Error)] -pub enum MemberTransformError -where - I: Transformable, - A: Transformable, -{ - /// Error transforming the `node` field - #[error(transparent)] - Node(#[from] NodeTransformError), - /// Error transforming the `tags` field - #[error(transparent)] - Tags(#[from] TagsTransformError), - /// Error transforming the `status` field - #[error(transparent)] - MemberStatus(#[from] UnknownMemberStatus), - /// Encode buffer too small - #[error("encode buffer too small")] - BufferTooSmall, - /// Not enough bytes to decode - #[error("not enough bytes to decode `Member`")] - NotEnoughBytes, - - /// Error transforming the `memberlist_protocol_version` field - 
#[error(transparent)] - MemberlistProtocolVersion(#[from] UnknownMemberlistProtocolVersion), - - /// Error transforming the `memberlist_delegate_version` field - #[error(transparent)] - MemberlistDelegateVersion(#[from] UnknownMemberlistDelegateVersion), - - /// Error transforming the `protocol_version` field - #[error(transparent)] - ProtocolVersion(#[from] UnknownProtocolVersion), - - /// Error transforming the `delegate_version` field - #[error(transparent)] - DelegateVersion(#[from] UnknownDelegateVersion), -} - -impl core::fmt::Debug for MemberTransformError -where - I: Transformable, - A: Transformable, -{ - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "{}", self) - } -} - -impl Transformable for Member -where - I: Transformable, - A: Transformable, -{ - type Error = MemberTransformError; - - fn encode(&self, dst: &mut [u8]) -> Result { - let encoded_len = self.encoded_len(); - if dst.len() < encoded_len { - return Err(Self::Error::BufferTooSmall); - } - - let mut offset = 0; - NetworkEndian::write_u32(&mut dst[offset..], encoded_len as u32); - offset += 4; - - offset += self.node.encode(&mut dst[offset..])?; - offset += self.tags.encode(&mut dst[offset..])?; - dst[offset] = self.status as u8; - offset += 1; - - dst[offset] = self.memberlist_protocol_version as u8; - offset += 1; - - dst[offset] = self.memberlist_delegate_version as u8; - offset += 1; - - dst[offset] = self.protocol_version as u8; - offset += 1; - - dst[offset] = self.delegate_version as u8; - offset += 1; - - debug_assert_eq!( - offset, encoded_len, - "expect write {} bytes, but actually write {} bytes", - offset, encoded_len - ); - - Ok(offset) - } - - fn encoded_len(&self) -> usize { - 4 + self.node.encoded_len() - + self.tags.encoded_len() - + 1 // status - + 1 // memberlist_protocol_version - + 1 // memberlist_delegate_version - + 1 // protocol_version - + 1 // delegate_version - } - - fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error> - 
where - Self: Sized, - { - let src_len = src.len(); - - if src_len < 4 { - return Err(Self::Error::NotEnoughBytes); - } - - let encoded_len = NetworkEndian::read_u32(&src[0..4]) as usize; - if src_len < encoded_len { - return Err(Self::Error::NotEnoughBytes); - } - - let mut offset = 4; - let (node_len, node) = Node::decode(&src[offset..])?; - offset += node_len; - - let (tags_len, tags) = Tags::decode(&src[offset..])?; - offset += tags_len; - - if src_len < offset + 5 { - return Err(Self::Error::NotEnoughBytes); - } - - let status = MemberStatus::try_from(src[offset])?; - offset += 1; - - let memberlist_protocol_version = MemberlistProtocolVersion::try_from(src[offset])?; - offset += 1; - - let memberlist_delegate_version = MemberlistDelegateVersion::try_from(src[offset])?; - offset += 1; - - let protocol_version = ProtocolVersion::try_from(src[offset])?; - offset += 1; - - let delegate_version = DelegateVersion::try_from(src[offset])?; - offset += 1; - - debug_assert_eq!( - offset, encoded_len, - "expect read {} bytes, but actually read {} bytes", - offset, encoded_len - ); - - Ok(( - encoded_len, - Self { - node, - tags: Arc::new(tags), - status, - memberlist_protocol_version, - memberlist_delegate_version, - protocol_version, - delegate_version, - }, - )) - } -} - -#[cfg(test)] -mod tests { - use std::net::SocketAddr; - - use rand::{distributions::Alphanumeric, random, thread_rng, Rng}; - use smol_str::SmolStr; - - use super::*; - - impl Member { - fn random(num_tags: usize, size: usize) -> Self { - let id = thread_rng() - .sample_iter(Alphanumeric) - .take(size) - .collect::>(); - let id = String::from_utf8(id).unwrap().into(); - let addr = SocketAddr::from(([127, 0, 0, 1], random::())); - let node = Node::new(id, addr); - let tags = Tags::random(num_tags, size); - - Self { - node, - tags: Arc::new(tags), - status: MemberStatus::Alive, - memberlist_protocol_version: MemberlistProtocolVersion::V1, - memberlist_delegate_version: MemberlistDelegateVersion::V1, - 
protocol_version: ProtocolVersion::V1, - delegate_version: DelegateVersion::V1, - } - } - } - - #[test] - fn member_encode_decode() { - futures::executor::block_on(async { - for i in 0..100 { - let filter = Member::random(i % 10, i); - let mut buf = vec![0; filter.encoded_len()]; - let encoded_len = filter.encode(&mut buf).unwrap(); - assert_eq!(encoded_len, filter.encoded_len()); - - let (decoded_len, decoded) = Member::::decode(&buf).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - - let (decoded_len, decoded) = - Member::::decode_from_reader(&mut std::io::Cursor::new(&buf)) - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - - let (decoded_len, decoded) = Member::::decode_from_async_reader( - &mut futures::io::Cursor::new(&buf), - ) - .await - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - } - }); - } -} diff --git a/types/src/push_pull.rs b/types/src/push_pull.rs deleted file mode 100644 index 2f7b7b5..0000000 --- a/types/src/push_pull.rs +++ /dev/null @@ -1,467 +0,0 @@ -use byteorder::{ByteOrder, NetworkEndian}; -use indexmap::{IndexMap, IndexSet}; -use memberlist_types::TinyVec; -use transformable::Transformable; - -use super::{LamportTime, LamportTimeTransformError, UserEvents, UserEventsTransformError}; - -/// Used when doing a state exchange. 
This -/// is a relatively large message, but is sent infrequently -#[viewit::viewit(setters(prefix = "with"))] -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[cfg_attr( - feature = "serde", - serde(bound( - serialize = "I: core::cmp::Eq + core::hash::Hash + serde::Serialize", - deserialize = "I: core::cmp::Eq + core::hash::Hash + serde::Deserialize<'de>" - )) -)] -pub struct PushPullMessage { - /// Current node lamport time - #[viewit( - getter(const, style = "move", attrs(doc = "Returns the lamport time")), - setter(const, attrs(doc = "Sets the lamport time (Builder pattern)")) - )] - ltime: LamportTime, - /// Maps the node to its status time - #[viewit( - getter( - const, - style = "ref", - attrs(doc = "Returns the maps the node to its status time") - ), - setter(attrs(doc = "Sets the maps the node to its status time (Builder pattern)")) - )] - status_ltimes: IndexMap, - /// List of left nodes - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns the list of left nodes")), - setter(attrs(doc = "Sets the list of left nodes (Builder pattern)")) - )] - left_members: IndexSet, - /// Lamport time for event clock - #[viewit( - getter( - const, - style = "move", - attrs(doc = "Returns the lamport time for event clock") - ), - setter( - const, - attrs(doc = "Sets the lamport time for event clock (Builder pattern)") - ) - )] - event_ltime: LamportTime, - /// Recent events - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns the recent events")), - setter(attrs(doc = "Sets the recent events (Builder pattern)")) - )] - events: TinyVec>, - /// Lamport time for query clock - #[viewit( - getter( - const, - style = "move", - attrs(doc = "Returns the lamport time for query clock") - ), - setter( - const, - attrs(doc = "Sets the lamport time for query clock (Builder pattern)") - ) - )] - query_ltime: LamportTime, -} - -impl PartialEq for PushPullMessage -where - I: core::hash::Hash + Eq, -{ - fn 
eq(&self, other: &Self) -> bool { - self.ltime == other.ltime - && self.status_ltimes == other.status_ltimes - && self.left_members == other.left_members - && self.event_ltime == other.event_ltime - && self.events == other.events - && self.query_ltime == other.query_ltime - } -} - -/// Used when doing a state exchange. This -/// is a relatively large message, but is sent infrequently -#[viewit::viewit(getters(skip), setters(skip))] -#[derive(Debug)] -#[cfg_attr(feature = "serde", derive(serde::Serialize))] -pub struct PushPullMessageRef<'a, I> { - /// Current node lamport time - ltime: LamportTime, - /// Maps the node to its status time - status_ltimes: &'a IndexMap, - /// List of left nodes - left_members: &'a IndexSet, - /// Lamport time for event clock - event_ltime: LamportTime, - /// Recent events - events: &'a [Option], - /// Lamport time for query clock - query_ltime: LamportTime, -} - -impl Clone for PushPullMessageRef<'_, I> { - fn clone(&self) -> Self { - *self - } -} - -impl Copy for PushPullMessageRef<'_, I> {} - -impl<'a, I> From<&'a PushPullMessage> for PushPullMessageRef<'a, I> { - #[inline] - fn from(msg: &'a PushPullMessage) -> Self { - Self { - ltime: msg.ltime, - status_ltimes: &msg.status_ltimes, - left_members: &msg.left_members, - event_ltime: msg.event_ltime, - events: &msg.events, - query_ltime: msg.query_ltime, - } - } -} - -impl<'a, I> From<&'a mut PushPullMessage> for PushPullMessageRef<'a, I> { - #[inline] - fn from(msg: &'a mut PushPullMessage) -> Self { - Self { - ltime: msg.ltime, - status_ltimes: &msg.status_ltimes, - left_members: &msg.left_members, - event_ltime: msg.event_ltime, - events: &msg.events, - query_ltime: msg.query_ltime, - } - } -} - -impl super::Encodable for PushPullMessageRef<'_, I> -where - I: Transformable, -{ - type Error = PushPullMessageTransformError; - - /// Returns the encoded length of the message - fn encoded_len(&self) -> usize { - 4 + Transformable::encoded_len(&self.ltime) - + 4 - + self - 
.status_ltimes - .iter() - .map(|(k, v)| Transformable::encoded_len(k) + Transformable::encoded_len(v)) - .sum::() - + 4 - + self - .left_members - .iter() - .map(Transformable::encoded_len) - .sum::() - + Transformable::encoded_len(&self.event_ltime) - + 4 - + self - .events - .iter() - .map(|e| match e { - Some(e) => 1 + Transformable::encoded_len(e), - None => 1, - }) - .sum::() - + Transformable::encoded_len(&self.query_ltime) - } - - /// Encodes the message into the given buffer - fn encode(&self, dst: &mut [u8]) -> Result> { - let encoded_len = self.encoded_len(); - if dst.len() < encoded_len { - return Err(PushPullMessageTransformError::BufferTooSmall); - } - - let mut offset = 0; - NetworkEndian::write_u32(&mut dst[offset..offset + 4], encoded_len as u32); - offset += 4; - - offset += Transformable::encode(&self.ltime, &mut dst[offset..])?; - let len = self.status_ltimes.len() as u32; - NetworkEndian::write_u32(&mut dst[offset..offset + 4], len); - offset += 4; - for (node, ltime) in self.status_ltimes.iter() { - offset += Transformable::encode(node, &mut dst[offset..]).map_err(Self::Error::Id)?; - offset += Transformable::encode(ltime, &mut dst[offset..])?; - } - - let len = self.left_members.len() as u32; - NetworkEndian::write_u32(&mut dst[offset..offset + 4], len); - offset += 4; - for node in self.left_members.iter() { - offset += Transformable::encode(node, &mut dst[offset..]).map_err(Self::Error::Id)?; - } - - offset += Transformable::encode(&self.event_ltime, &mut dst[offset..])?; - let len = self.events.len() as u32; - NetworkEndian::write_u32(&mut dst[offset..offset + 4], len); - offset += 4; - for e in self.events.iter() { - match e { - Some(e) => { - dst[offset] = 1; - offset += 1; - offset += Transformable::encode(e, &mut dst[offset..])?; - } - None => { - dst[offset] = 0; - offset += 1; - } - } - } - - offset += Transformable::encode(&self.query_ltime, &mut dst[offset..])?; - - debug_assert_eq!( - offset, encoded_len, - "expect write {} bytes, 
but actual write {} bytes", - encoded_len, offset - ); - - Ok(offset) - } -} - -/// Error that can occur when transforming a [`PushPullMessage`] or [`PushPullMessageRef`]. -#[derive(thiserror::Error)] -pub enum PushPullMessageTransformError -where - I: Transformable, -{ - /// Not enough bytes to decode [`PushPullMessage`] - #[error("not enough bytes to decode PushPullMessage")] - NotEnoughBytes, - /// Encode buffer too small - #[error("encode buffer too small")] - BufferTooSmall, - /// Error transforming [`I`] - #[error(transparent)] - Id(I::Error), - /// Error when we do not have enough nodes - #[error("expect {expect} nodes, but actual decode {got} nodes")] - MissingLeftMember { - /// Expect - expect: usize, - /// Actual - got: usize, - }, - /// Error when we do not have enough status time - #[error("expect {expect} status time, but actual decode {got} status time")] - MissingNodeStatusTime { - /// Expect - expect: usize, - /// Actual - got: usize, - }, - /// Error transforming [`LamportTime`] - #[error(transparent)] - LamportTime(#[from] LamportTimeTransformError), - /// Error transforming [`UserEvents`] - #[error(transparent)] - UserEvents(#[from] UserEventsTransformError), - /// Error when we do not have enough events - #[error("expect {expect} events, but actual decode {got} events")] - MissingEvents { - /// Expect - expect: usize, - /// Actual - got: usize, - }, -} - -impl core::fmt::Debug for PushPullMessageTransformError -where - I: Transformable, -{ - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "{}", self) - } -} - -impl Transformable for PushPullMessage -where - I: Transformable + core::hash::Hash + Eq, -{ - type Error = PushPullMessageTransformError; - - fn encode(&self, dst: &mut [u8]) -> Result { - super::Encodable::encode(&PushPullMessageRef::from(self), dst) - } - - fn encoded_len(&self) -> usize { - super::Encodable::encoded_len(&PushPullMessageRef::from(self)) - } - - fn decode(src: &[u8]) -> Result<(usize, 
Self), Self::Error> - where - Self: Sized, - { - let src_len = src.len(); - if src_len < 4 { - return Err(PushPullMessageTransformError::NotEnoughBytes); - } - - let encoded_len = NetworkEndian::read_u32(&src[..4]) as usize; - if src_len < encoded_len { - return Err(PushPullMessageTransformError::NotEnoughBytes); - } - - let mut offset = 4; - let (n, ltime) = LamportTime::decode(&src[offset..])?; - offset += n; - - let len = NetworkEndian::read_u32(&src[offset..offset + 4]) as usize; - offset += 4; - - let mut status_ltimes = IndexMap::with_capacity(len); - for _ in 0..len { - let (n, node) = I::decode(&src[offset..]).map_err(Self::Error::Id)?; - offset += n; - let (n, ltime) = LamportTime::decode(&src[offset..])?; - offset += n; - status_ltimes.insert(node, ltime); - } - - let len = NetworkEndian::read_u32(&src[offset..offset + 4]) as usize; - offset += 4; - - let mut left_members = IndexSet::with_capacity(len); - for _ in 0..len { - let (n, node) = I::decode(&src[offset..]).map_err(Self::Error::Id)?; - offset += n; - left_members.insert(node); - } - - let (n, event_ltime) = LamportTime::decode(&src[offset..])?; - offset += n; - - let len = NetworkEndian::read_u32(&src[offset..offset + 4]) as usize; - offset += 4; - - let mut events = TinyVec::with_capacity(len); - for _ in 0..len { - let has_event = src[offset]; - offset += 1; - if has_event == 1 { - let (n, event) = UserEvents::decode(&src[offset..])?; - offset += n; - events.push(Some(event)); - } else { - events.push(None); - } - } - - let (n, query_ltime) = LamportTime::decode(&src[offset..])?; - offset += n; - - debug_assert_eq!( - offset, encoded_len, - "expect read {} bytes, but actual read {} bytes", - encoded_len, offset - ); - - Ok(( - encoded_len, - PushPullMessage { - ltime, - status_ltimes, - left_members, - event_ltime, - events, - query_ltime, - }, - )) - } -} - -#[cfg(test)] -mod tests { - use rand::{distributions::Alphanumeric, thread_rng, Rng}; - use smol_str::SmolStr; - - use super::*; - - impl 
PushPullMessage { - fn random(size: usize) -> Self { - let mut status_ltimes = IndexMap::new(); - for _ in 0..size { - let id = thread_rng() - .sample_iter(Alphanumeric) - .take(size) - .collect::>(); - let id = String::from_utf8(id).unwrap().into(); - - status_ltimes.insert(id, LamportTime::random()); - } - - let mut left_members = IndexSet::new(); - for _ in 0..size { - let id = thread_rng() - .sample_iter(Alphanumeric) - .take(size) - .collect::>(); - let id = String::from_utf8(id).unwrap().into(); - left_members.insert(id); - } - - let mut events = TinyVec::new(); - for i in 0..size { - if i % 2 == 0 { - events.push(None); - } else { - events.push(Some(UserEvents::random(size, size % 10))); - } - } - - Self { - ltime: LamportTime::random(), - status_ltimes, - left_members, - event_ltime: LamportTime::random(), - events, - query_ltime: LamportTime::random(), - } - } - } - - #[test] - fn test_push_pull_message_transform() { - futures::executor::block_on(async { - for i in 0..100 { - let msg = PushPullMessage::random(i); - let mut buf = vec![0; msg.encoded_len()]; - let encoded_len = msg.encode(&mut buf).unwrap(); - assert_eq!(encoded_len, msg.encoded_len()); - - let (decoded_len, decoded) = PushPullMessage::::decode(&buf).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, msg); - - let (decoded_len, decoded) = - PushPullMessage::::decode_from_reader(&mut std::io::Cursor::new(&buf)).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, msg); - - let (decoded_len, decoded) = - PushPullMessage::::decode_from_async_reader(&mut futures::io::Cursor::new(&buf)) - .await - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, msg); - } - }); - } -} diff --git a/types/src/query.rs b/types/src/query.rs deleted file mode 100644 index 04f06dc..0000000 --- a/types/src/query.rs +++ /dev/null @@ -1,631 +0,0 @@ -use byteorder::{ByteOrder, NetworkEndian}; -use smol_str::SmolStr; -use transformable::{ - BytesTransformError, 
DurationTransformError, StringTransformError, Transformable, -}; - -use std::time::Duration; - -use memberlist_types::{bytes::Bytes, Node, NodeTransformError, TinyVec}; - -use super::{LamportTime, LamportTimeTransformError}; - -bitflags::bitflags! { - /// Flags for query message - #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] - #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] - #[cfg_attr(feature = "serde", serde(transparent))] - pub struct QueryFlag: u32 { - /// Ack flag is used to force receiver to send an ack back - const ACK = 1 << 0; - /// NoBroadcast is used to prevent re-broadcast of a query. - /// this can be used to selectively send queries to individual members - const NO_BROADCAST = 1 << 1; - } -} - -/// Query message -#[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))] -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct QueryMessage { - /// Event lamport time - #[viewit( - getter(const, style = "move", attrs(doc = "Returns the event lamport time")), - setter(const, attrs(doc = "Sets the event lamport time (Builder pattern)")) - )] - ltime: LamportTime, - /// query id, randomly generated - #[viewit( - getter(const, style = "move", attrs(doc = "Returns the query id")), - setter(attrs(doc = "Sets the query id (Builder pattern)")) - )] - id: u32, - /// source node - #[viewit( - getter(const, attrs(doc = "Returns the from node")), - setter(attrs(doc = "Sets the from node (Builder pattern)")) - )] - from: Node, - /// Potential query filters - #[viewit( - getter(const, attrs(doc = "Returns the potential query filters")), - setter(attrs(doc = "Sets the potential query filters (Builder pattern)")) - )] - filters: TinyVec, - /// Used to provide various flags - #[viewit( - getter(const, style = "move", attrs(doc = "Returns the flags")), - setter(attrs(doc = "Sets the flags (Builder pattern)")) - )] - flags: QueryFlag, - /// Used to set the 
number of duplicate relayed responses - #[viewit( - getter( - const, - style = "move", - attrs(doc = "Returns the number of duplicate relayed responses") - ), - setter(attrs(doc = "Sets the number of duplicate relayed responses (Builder pattern)")) - )] - relay_factor: u8, - /// Maximum time between delivery and response - #[viewit( - getter( - const, - style = "move", - attrs(doc = "Returns the maximum time between delivery and response") - ), - setter(attrs(doc = "Sets the maximum time between delivery and response (Builder pattern)")) - )] - timeout: Duration, - /// Query nqme - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns the name of the query")), - setter(attrs(doc = "Sets the name of the query (Builder pattern)")) - )] - name: SmolStr, - /// Query payload - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns the payload")), - setter(attrs(doc = "Sets the payload (Builder pattern)")) - )] - payload: Bytes, -} - -impl QueryMessage { - /// Checks if the ack flag is set - #[inline] - pub fn ack(&self) -> bool { - self.flags.contains(QueryFlag::ACK) - } - - /// Checks if the no broadcast flag is set - #[inline] - pub fn no_broadcast(&self) -> bool { - self.flags.contains(QueryFlag::NO_BROADCAST) - } -} - -/// Error that can occur when transforming a [`QueryMessage`]. 
-#[derive(thiserror::Error)] -pub enum QueryMessageTransformError -where - I: Transformable, - A: Transformable, -{ - /// Not enough bytes to decode QueryMessage - #[error("not enough bytes to decode QueryMessage")] - NotEnoughBytes, - /// Encode buffer too small - #[error("encode buffer too small")] - BufferTooSmall, - /// Error transforming `from` field - #[error(transparent)] - From(#[from] NodeTransformError), - /// Error transforming `ltime` field - #[error(transparent)] - LamportTime(#[from] LamportTimeTransformError), - /// Error transforming `payload` field - #[error(transparent)] - Payload(BytesTransformError), - - /// Error transforming `filters` field - #[error(transparent)] - Filters(BytesTransformError), - - /// Error transforming `name` field - #[error(transparent)] - Name(#[from] StringTransformError), - - /// Error transforming `timeout` field - #[error(transparent)] - Timeout(#[from] DurationTransformError), -} - -impl core::fmt::Debug for QueryMessageTransformError -where - I: Transformable, - A: Transformable, -{ - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "{}", self) - } -} - -impl Transformable for QueryMessage -where - I: Transformable, - A: Transformable, -{ - type Error = QueryMessageTransformError; - - fn encode(&self, dst: &mut [u8]) -> Result { - let encoded_len = self.encoded_len(); - if dst.len() < encoded_len { - return Err(Self::Error::BufferTooSmall); - } - - let mut offset = 0; - NetworkEndian::write_u32(&mut dst[offset..], encoded_len as u32); - offset += 4; - offset += self.ltime.encode(&mut dst[offset..])?; - NetworkEndian::write_u32(&mut dst[offset..], self.id); - offset += 4; - offset += self.from.encode(&mut dst[offset..])?; - NetworkEndian::write_u32(&mut dst[offset..], self.filters.len() as u32); - offset += 4; - for filter in self.filters.iter() { - offset += filter - .encode(&mut dst[offset..]) - .map_err(Self::Error::Filters)?; - } - NetworkEndian::write_u32(&mut dst[offset..], 
self.flags.bits()); - offset += 4; - dst[offset] = self.relay_factor; - offset += 1; - offset += self.timeout.encode(&mut dst[offset..])?; - offset += self.name.encode(&mut dst[offset..])?; - offset += self - .payload - .encode(&mut dst[offset..]) - .map_err(Self::Error::Payload)?; - - debug_assert_eq!( - offset, encoded_len, - "expect write {} bytes, but actual write {} bytes", - encoded_len, offset - ); - - Ok(offset) - } - - fn encoded_len(&self) -> usize { - 4 + self.ltime.encoded_len() - + 4 // id - + self.from.encoded_len() - + 4 // num filters - + self.filters.iter().map(|f| f.encoded_len()).sum::() - + 4 // flags - + 1 // relay_factor - + self.timeout.encoded_len() - + self.name.encoded_len() - + self.payload.encoded_len() - } - - fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error> - where - Self: Sized, - { - let src_len = src.len(); - if src.len() < 4 { - return Err(Self::Error::NotEnoughBytes); - } - - let mut offset = 0; - let len = NetworkEndian::read_u32(&src[offset..]) as usize; - if src.len() < len { - return Err(Self::Error::NotEnoughBytes); - } - offset += 4; - - let (n, ltime) = LamportTime::decode(&src[offset..])?; - offset += n; - - if offset + 4 > src_len { - return Err(Self::Error::NotEnoughBytes); - } - - let id = NetworkEndian::read_u32(&src[offset..]); - offset += 4; - - let (n, from) = Node::decode(&src[offset..])?; - offset += n; - - if offset + 4 > src_len { - return Err(Self::Error::NotEnoughBytes); - } - - let num_filters = NetworkEndian::read_u32(&src[offset..]) as usize; - offset += 4; - - let mut filters = TinyVec::with_capacity(num_filters); - for _ in 0..num_filters { - let (n, filter) = Bytes::decode(&src[offset..]).map_err(Self::Error::Filters)?; - filters.push(filter); - offset += n; - } - - if offset + 4 > src_len { - return Err(Self::Error::NotEnoughBytes); - } - - let flags = QueryFlag::from_bits_retain(NetworkEndian::read_u32(&src[offset..])); - offset += 4; - - if offset + 1 > src_len { - return 
Err(Self::Error::NotEnoughBytes); - } - - let relay_factor = src[offset]; - offset += 1; - - let (n, timeout) = Duration::decode(&src[offset..])?; - offset += n; - - let (n, name) = SmolStr::decode(&src[offset..])?; - offset += n; - - let (n, payload) = Bytes::decode(&src[offset..]).map_err(Self::Error::Payload)?; - offset += n; - - debug_assert_eq!( - offset, len, - "expect read {} bytes, but actual read {} bytes", - len, offset - ); - - Ok(( - offset, - Self { - ltime, - id, - from, - filters, - flags, - relay_factor, - timeout, - name, - payload, - }, - )) - } -} - -/// Query response message -#[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))] -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct QueryResponseMessage { - /// Event lamport time - #[viewit( - getter(const, attrs(doc = "Returns the lamport time for this message")), - setter( - const, - attrs(doc = "Sets the lamport time for this message (Builder pattern)") - ) - )] - ltime: LamportTime, - /// query id - #[viewit( - getter(const, attrs(doc = "Returns the query id")), - setter(attrs(doc = "Sets the query id (Builder pattern)")) - )] - id: u32, - /// node - #[viewit( - getter(const, attrs(doc = "Returns the from node")), - setter(attrs(doc = "Sets the from node (Builder pattern)")) - )] - from: Node, - /// Used to provide various flags - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns the flags")), - setter(attrs(doc = "Sets the flags (Builder pattern)")) - )] - flags: QueryFlag, - /// Optional response payload - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns the payload")), - setter(attrs(doc = "Sets the payload (Builder pattern)")) - )] - payload: Bytes, -} - -impl QueryResponseMessage { - /// Checks if the ack flag is set - #[inline] - pub fn ack(&self) -> bool { - self.flags.contains(QueryFlag::ACK) - } - - /// Checks if the no broadcast flag is set - #[inline] - pub fn 
no_broadcast(&self) -> bool { - self.flags.contains(QueryFlag::NO_BROADCAST) - } -} - -/// Error that can occur when transforming a [`QueryResponseMessage`]. -#[derive(thiserror::Error)] -pub enum QueryResponseMessageTransformError -where - I: Transformable, - A: Transformable, -{ - /// Not enough bytes to decode QueryResponseMessage - #[error("not enough bytes to decode QueryResponseMessage")] - NotEnoughBytes, - /// Encode buffer too small - #[error("encode buffer too small")] - BufferTooSmall, - /// Error transforming Node - #[error(transparent)] - Node(#[from] NodeTransformError), - /// Error transforming LamportTime - #[error(transparent)] - LamportTime(#[from] LamportTimeTransformError), - /// Error transforming payload - #[error(transparent)] - Payload(#[from] BytesTransformError), -} - -impl core::fmt::Debug for QueryResponseMessageTransformError -where - I: Transformable, - A: Transformable, -{ - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "{}", self) - } -} - -impl Transformable for QueryResponseMessage -where - I: Transformable, - A: Transformable, -{ - type Error = QueryResponseMessageTransformError; - - fn encode(&self, dst: &mut [u8]) -> Result { - let encoded_len = self.encoded_len(); - if dst.len() < encoded_len { - return Err(Self::Error::BufferTooSmall); - } - - let mut offset = 0; - NetworkEndian::write_u32(&mut dst[offset..], encoded_len as u32); - offset += 4; - offset += self.ltime.encode(&mut dst[offset..])?; - NetworkEndian::write_u32(&mut dst[offset..], self.id); - offset += 4; - offset += self.from.encode(&mut dst[offset..])?; - NetworkEndian::write_u32(&mut dst[offset..], self.flags.bits()); - offset += 4; - offset += self.payload.encode(&mut dst[offset..])?; - - debug_assert_eq!( - offset, encoded_len, - "expect write {} bytes, but actual write {} bytes", - encoded_len, offset - ); - - Ok(offset) - } - - fn encoded_len(&self) -> usize { - 4 + self.ltime.encoded_len() + 4 + self.from.encoded_len() + 
4 + self.payload.encoded_len() - } - - fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error> - where - Self: Sized, - { - let src_len = src.len(); - if src.len() < 4 { - return Err(Self::Error::NotEnoughBytes); - } - - let mut offset = 0; - let len = NetworkEndian::read_u32(&src[offset..]) as usize; - if src.len() < len { - return Err(Self::Error::NotEnoughBytes); - } - - offset += 4; - let (n, ltime) = LamportTime::decode(&src[offset..])?; - offset += n; - - if offset + 4 > src_len { - return Err(Self::Error::NotEnoughBytes); - } - let id = NetworkEndian::read_u32(&src[offset..]); - offset += 4; - - let (n, from) = Node::decode(&src[offset..])?; - offset += n; - - if offset + 4 > src_len { - return Err(Self::Error::NotEnoughBytes); - } - - let flags = QueryFlag::from_bits_retain(NetworkEndian::read_u32(&src[offset..])); - offset += 4; - - let (n, payload) = Bytes::decode(&src[offset..])?; - offset += n; - - debug_assert_eq!( - offset, len, - "expect read {} bytes, but actual read {} bytes", - len, offset - ); - - Ok(( - offset, - Self { - ltime, - id, - from, - flags, - payload, - }, - )) - } -} - -#[cfg(test)] -mod tests { - use std::net::SocketAddr; - - use rand::{distributions::Alphanumeric, random, thread_rng, Rng}; - - use super::*; - - impl QueryMessage { - fn random(size: usize, num_filters: usize) -> Self { - let ltime = LamportTime::random(); - let id = random(); - let from_id = thread_rng() - .sample_iter(Alphanumeric) - .take(size) - .collect::>(); - let from_id = String::from_utf8(from_id).unwrap().into(); - let addr = SocketAddr::from(([127, 0, 0, 1], random::())); - let from = Node::new(from_id, addr); - let filters = (0..num_filters) - .map(|_| { - let payload = thread_rng() - .sample_iter(Alphanumeric) - .take(size) - .collect::>(); - payload.into() - }) - .collect(); - let flags = QueryFlag::empty(); - let relay_factor = random(); - let timeout = Duration::from_secs(random::()); - let name = thread_rng() - .sample_iter(Alphanumeric) - 
.take(size) - .collect::>(); - let name = SmolStr::from(String::from_utf8(name).unwrap()); - let payload = thread_rng() - .sample_iter(Alphanumeric) - .take(size) - .collect::>(); - let payload = Bytes::from(payload); - Self { - ltime, - id, - from, - filters, - flags, - relay_factor, - timeout, - name, - payload, - } - } - } - - impl QueryResponseMessage { - fn random(size: usize) -> Self { - let id = rand::random(); - - let from_id = thread_rng() - .sample_iter(Alphanumeric) - .take(size) - .collect::>(); - let from_id = String::from_utf8(from_id).unwrap().into(); - let addr = SocketAddr::from(([127, 0, 0, 1], random::())); - let from = Node::new(from_id, addr); - let flags = QueryFlag::empty(); - let payload = thread_rng() - .sample_iter(Alphanumeric) - .take(size) - .collect::>(); - Self { - ltime: LamportTime::random(), - id, - from, - flags, - payload: payload.into(), - } - } - } - - #[test] - fn test_query_response_transform() { - futures::executor::block_on(async { - for i in 0..100 { - let filter = QueryResponseMessage::random(i); - let mut buf = vec![0; filter.encoded_len()]; - let encoded_len = filter.encode(&mut buf).unwrap(); - assert_eq!(encoded_len, filter.encoded_len()); - - let (decoded_len, decoded) = - QueryResponseMessage::::decode(&buf).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - - let (decoded_len, decoded) = - QueryResponseMessage::::decode_from_reader( - &mut std::io::Cursor::new(&buf), - ) - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - - let (decoded_len, decoded) = - QueryResponseMessage::::decode_from_async_reader( - &mut futures::io::Cursor::new(&buf), - ) - .await - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - } - }); - } - - #[test] - fn test_query_message_transform() { - futures::executor::block_on(async { - for i in 0..100 { - let filter = QueryMessage::random(i, i % 10); - let mut buf = vec![0; filter.encoded_len()]; 
- let encoded_len = filter.encode(&mut buf).unwrap(); - assert_eq!(encoded_len, filter.encoded_len()); - - let (decoded_len, decoded) = QueryMessage::::decode(&buf).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - - let (decoded_len, decoded) = - QueryMessage::::decode_from_reader(&mut std::io::Cursor::new(&buf)) - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - - let (decoded_len, decoded) = QueryMessage::::decode_from_async_reader( - &mut futures::io::Cursor::new(&buf), - ) - .await - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - } - }); - } -} diff --git a/types/src/tags.rs b/types/src/tags.rs deleted file mode 100644 index f65df74..0000000 --- a/types/src/tags.rs +++ /dev/null @@ -1,202 +0,0 @@ -use byteorder::{ByteOrder, NetworkEndian}; -use indexmap::IndexMap; -use smol_str::SmolStr; -use transformable::Transformable; - -/// Tags of a node -#[derive( - Debug, - Default, - PartialEq, - Clone, - derive_more::From, - derive_more::Into, - derive_more::Deref, - derive_more::DerefMut, -)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[cfg_attr(feature = "serde", serde(transparent))] -pub struct Tags(IndexMap); - -impl IntoIterator for Tags { - type Item = (SmolStr, SmolStr); - type IntoIter = indexmap::map::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() - } -} - -impl FromIterator<(SmolStr, SmolStr)> for Tags { - fn from_iter>(iter: T) -> Self { - Self(iter.into_iter().collect()) - } -} - -impl<'a> FromIterator<(&'a str, &'a str)> for Tags { - fn from_iter>(iter: T) -> Self { - Self( - iter - .into_iter() - .map(|(k, v)| (SmolStr::new(k), SmolStr::new(v))) - .collect(), - ) - } -} - -impl Tags { - /// Create a new Tags - #[inline] - pub fn new() -> Self { - Self(IndexMap::new()) - } - - /// Create a new Tags with a capacity - pub fn with_capacity(cap: usize) -> Self { - 
Self(IndexMap::with_capacity(cap)) - } -} - -/// Error that can occur when transforming [`Tags`]. -#[derive(Debug, thiserror::Error)] -pub enum TagsTransformError { - /// Not enough bytes to decode Tags - #[error("not enough bytes to decode `Tags`")] - NotEnoughBytes, - /// Encode buffer too small - #[error("encode buffer too small")] - BufferTooSmall, - /// Error transforming a string - #[error(transparent)] - String(#[from] transformable::StringTransformError), -} - -impl Transformable for Tags { - type Error = TagsTransformError; - - fn encode(&self, dst: &mut [u8]) -> Result { - let encoded_len = self.encoded_len(); - if dst.len() < encoded_len { - return Err(Self::Error::BufferTooSmall); - } - - let mut offset = 0; - NetworkEndian::write_u32(&mut dst[offset..offset + 4], encoded_len as u32); - offset += 4; - let len = self.0.len() as u32; - NetworkEndian::write_u32(&mut dst[offset..offset + 4], len); - offset += 4; - for (key, value) in self.0.iter() { - offset += key.encode(&mut dst[offset..])?; - offset += value.encode(&mut dst[offset..])?; - } - - debug_assert_eq!( - offset, encoded_len, - "expect write {} bytes, but actual write {} bytes", - encoded_len, offset - ); - - Ok(offset) - } - - fn encoded_len(&self) -> usize { - 4 + 4 - + self - .0 - .iter() - .map(|(key, value)| key.encoded_len() + value.encoded_len()) - .sum::() - } - - fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error> - where - Self: Sized, - { - let src_len = src.len(); - if src_len < 8 { - return Err(Self::Error::NotEnoughBytes); - } - - let len = NetworkEndian::read_u32(&src[0..4]) as usize; - if src_len < len { - return Err(Self::Error::NotEnoughBytes); - } - - let total_tags = NetworkEndian::read_u32(&src[4..8]) as usize; - let mut offset = 8; - let mut tags = IndexMap::with_capacity(total_tags); - for _ in 0..total_tags { - let (n, key) = SmolStr::decode(&src[offset..])?; - offset += n; - let (n, value) = SmolStr::decode(&src[offset..])?; - offset += n; - tags.insert(key, 
value); - } - - debug_assert_eq!( - len, offset, - "expected read {} bytes, but actual read {} bytes", - len, offset - ); - - Ok((offset, Self(tags))) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use rand::{distributions::Alphanumeric, Rng}; - - impl Tags { - pub(crate) fn random(num_tags: usize, size: usize) -> Self { - let mut tags = IndexMap::with_capacity(num_tags); - for _ in 0..num_tags { - let rng = rand::thread_rng(); - let name = rng - .sample_iter(&Alphanumeric) - .take(size) - .collect::>(); - let name = String::from_utf8(name).unwrap(); - - let rng = rand::thread_rng(); - let payload = rng - .sample_iter(&Alphanumeric) - .take(size) - .collect::>(); - - tags.insert(name.into(), String::from_utf8(payload).unwrap().into()); - } - Self(tags) - } - } - - #[test] - fn test_tags_transform() { - futures::executor::block_on(async { - for i in 0..100 { - let event = Tags::random(i % 10, i); - let mut buf = vec![0; event.encoded_len()]; - let encoded_len = event.encode(&mut buf).unwrap(); - assert_eq!(encoded_len, event.encoded_len()); - - let (decoded_len, decoded) = Tags::decode(&buf).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, event); - - let (decoded_len, decoded) = - Tags::decode_from_reader(&mut std::io::Cursor::new(&buf)).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, event); - - let (decoded_len, decoded) = - Tags::decode_from_async_reader(&mut futures::io::Cursor::new(&buf)) - .await - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, event); - } - }); - } -} diff --git a/types/src/user_event.rs b/types/src/user_event.rs deleted file mode 100644 index 9dbb5e5..0000000 --- a/types/src/user_event.rs +++ /dev/null @@ -1,525 +0,0 @@ -use byteorder::{ByteOrder, NetworkEndian}; -use memberlist_types::{bytes::Bytes, CheapClone, OneOrMore}; -use smol_str::SmolStr; -use transformable::{BytesTransformError, StringTransformError, Transformable}; - -use super::{LamportTime, 
LamportTimeTransformError}; - -/// Used to buffer events to prevent re-delivery -#[viewit::viewit(setters(prefix = "with"))] -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct UserEvents { - /// The lamport time - #[viewit( - getter(const, attrs(doc = "Returns the lamport time for this message")), - setter( - const, - attrs(doc = "Sets the lamport time for this message (Builder pattern)") - ) - )] - ltime: LamportTime, - - /// The user events - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns the user events")), - setter(attrs(doc = "Sets the user events (Builder pattern)")) - )] - events: OneOrMore, -} - -/// Error that can occur when transforming a [`UserEvents`] -#[derive(Debug, thiserror::Error)] -pub enum UserEventsTransformError { - /// Encode buffer too small - #[error("encode buffer too small")] - BufferTooSmall, - /// Not enough bytes to decode [`UserEvents`] - #[error("not enough bytes to decode `UserEvents`")] - NotEnoughBytes, - /// Error transforming [`UserEvent`] - #[error(transparent)] - Event(#[from] UserEventTransformError), - /// Error transforming [`LamportTime`] - #[error(transparent)] - LamportTime(#[from] LamportTimeTransformError), -} - -impl Transformable for UserEvents { - type Error = UserEventsTransformError; - - fn encode(&self, dst: &mut [u8]) -> Result { - let encoded_len = self.encoded_len(); - - if dst.len() < encoded_len { - return Err(Self::Error::BufferTooSmall); - } - - let mut offset = 0; - NetworkEndian::write_u32(&mut dst[offset..], encoded_len as u32); - offset += 4; - - offset += self.ltime.encode(&mut dst[offset..])?; - NetworkEndian::write_u32(&mut dst[offset..], self.events.len() as u32); - offset += 4; - - for event in self.events.iter() { - offset += event.encode(&mut dst[offset..])?; - } - - debug_assert_eq!( - offset, encoded_len, - "expect write {} bytes, actual read {} bytes", - encoded_len, offset - ); - - Ok(offset) - } - - 
fn encoded_len(&self) -> usize { - 4 + self.ltime.encoded_len() - + 4 - + self - .events - .iter() - .map(UserEvent::encoded_len) - .sum::() - } - - fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error> - where - Self: Sized, - { - let src_len = src.len(); - if src_len < 4 { - return Err(Self::Error::NotEnoughBytes); - } - - let len = NetworkEndian::read_u32(&src[0..4]) as usize; - if src_len < len { - return Err(Self::Error::NotEnoughBytes); - } - - let mut offset = 4; - let (ltime_offset, ltime) = LamportTime::decode(&src[offset..])?; - offset += ltime_offset; - - let event_len = NetworkEndian::read_u32(&src[offset..]) as usize; - offset += 4; - - let mut events = OneOrMore::with_capacity(event_len); - for _ in 0..event_len { - let (event_offset, event) = UserEvent::decode(&src[offset..])?; - offset += event_offset; - events.push(event); - } - - debug_assert_eq!( - offset, len, - "expect read {} bytes, actual read {} bytes", - len, offset - ); - - Ok((len, Self { ltime, events })) - } -} - -/// Stores all the user events at a specific time -#[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))] -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct UserEvent { - /// The name of the event - #[viewit( - getter(const, attrs(doc = "Returns the name of the event")), - setter(attrs(doc = "Sets the name of the event (Builder pattern)")) - )] - name: SmolStr, - /// The payload of the event - #[viewit( - getter(const, attrs(doc = "Returns the payload of the event")), - setter(attrs(doc = "Sets the payload of the event (Builder pattern)")) - )] - payload: Bytes, -} - -/// Error that can occur when transforming a [`UserEvent`] -#[derive(Debug, thiserror::Error)] -pub enum UserEventTransformError { - /// Not enough bytes to decode UserEvent - #[error("not enough bytes to decode `UserEvent`")] - NotEnoughBytes, - /// Encode buffer too small - #[error("encode buffer too small")] - 
BufferTooSmall, - - /// Error transforming SmolStr - #[error(transparent)] - Name(#[from] StringTransformError), - - /// Error transforming Bytes - #[error(transparent)] - Payload(#[from] BytesTransformError), -} - -impl Transformable for UserEvent { - type Error = UserEventTransformError; - - fn encode(&self, dst: &mut [u8]) -> Result { - let encoded_len = self.encoded_len(); - if dst.len() < encoded_len { - return Err(Self::Error::BufferTooSmall); - } - - let mut offset = 0; - NetworkEndian::write_u32(&mut dst[offset..], encoded_len as u32); - offset += 4; - - offset += self.name.encode(&mut dst[offset..])?; - offset += self.payload.encode(&mut dst[offset..])?; - - debug_assert_eq!( - offset, encoded_len, - "expect write {} bytes, actual read {} bytes", - encoded_len, offset - ); - - Ok(offset) - } - - fn encoded_len(&self) -> usize { - 4 + self.name.encoded_len() + self.payload.encoded_len() - } - - fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error> - where - Self: Sized, - { - let src_len = src.len(); - if src_len < 4 { - return Err(Self::Error::NotEnoughBytes); - } - - let len = NetworkEndian::read_u32(&src[0..4]) as usize; - if src_len < len { - return Err(Self::Error::NotEnoughBytes); - } - - let mut offset = 4; - let (name_offset, name) = SmolStr::decode(&src[offset..])?; - offset += name_offset; - let (payload_offset, payload) = Bytes::decode(&src[offset..])?; - offset += payload_offset; - - debug_assert_eq!( - offset, len, - "expect read {} bytes, actual read {} bytes", - len, offset - ); - - Ok((len, Self { name, payload })) - } -} - -/// Used for user-generated events -#[viewit::viewit(setters(prefix = "with"))] -#[derive(Debug, Default, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct UserEventMessage { - /// The lamport time - #[viewit( - getter( - const, - style = "move", - attrs(doc = "Returns the lamport time for this message") - ), - setter( - const, - attrs(doc = "Sets the 
lamport time for this message (Builder pattern)") - ) - )] - ltime: LamportTime, - /// The name of the event - #[viewit( - getter(const, attrs(doc = "Returns the name of the event")), - setter(attrs(doc = "Sets the name of the event (Builder pattern)")) - )] - name: SmolStr, - /// The payload of the event - #[viewit( - getter(const, attrs(doc = "Returns the payload of the event")), - setter(attrs(doc = "Sets the payload of the event (Builder pattern)")) - )] - payload: Bytes, - /// "Can Coalesce". - #[viewit( - getter( - const, - style = "move", - attrs(doc = "Returns if this message can be coalesced") - ), - setter( - const, - attrs(doc = "Sets if this message can be coalesced (Builder pattern)") - ) - )] - cc: bool, -} - -impl CheapClone for UserEventMessage { - fn cheap_clone(&self) -> Self { - Self { - ltime: self.ltime, - name: self.name.cheap_clone(), - payload: self.payload.clone(), - cc: self.cc, - } - } -} - -/// Error that can occur when transforming a [`UserEventMessage`] -#[derive(Debug, thiserror::Error)] -pub enum UserEventMessageTransformError { - /// Not enough bytes to decode UserEventMessage - #[error("not enough bytes to decode `UserEventMessage`")] - NotEnoughBytes, - /// Encode buffer too small - #[error("encode buffer too small")] - BufferTooSmall, - - /// Error transforming LamportTime - #[error(transparent)] - LamportTime(#[from] LamportTimeTransformError), - - /// Error transforming SmolStr - #[error(transparent)] - Name(#[from] StringTransformError), - - /// Error transforming Bytes - #[error(transparent)] - Payload(#[from] BytesTransformError), -} - -impl Transformable for UserEventMessage { - type Error = UserEventMessageTransformError; - - fn encode(&self, dst: &mut [u8]) -> Result { - let encoded_len = self.encoded_len(); - if dst.len() < encoded_len { - return Err(Self::Error::BufferTooSmall); - } - - let mut offset = 0; - NetworkEndian::write_u32(&mut dst[offset..], encoded_len as u32); - offset += 4; - dst[offset] = self.cc as u8; - 
offset += 1; - offset += self.ltime.encode(&mut dst[offset..])?; - offset += self.name.encode(&mut dst[offset..])?; - offset += self.payload.encode(&mut dst[offset..])?; - - debug_assert_eq!( - offset, encoded_len, - "expect write {} bytes, actual read {} bytes", - encoded_len, offset - ); - - Ok(offset) - } - - fn encoded_len(&self) -> usize { - 4 + self.ltime.encoded_len() + self.name.encoded_len() + self.payload.encoded_len() + 1 - } - - fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error> - where - Self: Sized, - { - let src_len = src.len(); - if src_len < 4 { - return Err(Self::Error::NotEnoughBytes); - } - - let len = NetworkEndian::read_u32(&src[0..4]) as usize; - if src_len < len { - return Err(Self::Error::NotEnoughBytes); - } - - let mut offset = 4; - let cc = src[offset] != 0; - offset += 1; - let (ltime_offset, ltime) = LamportTime::decode(&src[offset..])?; - offset += ltime_offset; - let (name_offset, name) = SmolStr::decode(&src[offset..])?; - offset += name_offset; - let (payload_offset, payload) = Bytes::decode(&src[offset..])?; - offset += payload_offset; - - debug_assert_eq!( - offset, len, - "expect read {} bytes, actual read {} bytes", - len, offset - ); - - Ok(( - len, - Self { - ltime, - name, - payload, - cc, - }, - )) - } -} - -#[cfg(test)] -mod tests { - use rand::{distributions::Alphanumeric, random, Rng}; - - use super::*; - - impl UserEvent { - fn random(size: usize) -> Self { - let rng = rand::thread_rng(); - let name = rng - .sample_iter(&Alphanumeric) - .take(size) - .collect::>(); - let name = String::from_utf8(name).unwrap(); - - let rng = rand::thread_rng(); - let payload = rng - .sample_iter(&Alphanumeric) - .take(size) - .collect::>(); - - Self { - name: name.into(), - payload: payload.into(), - } - } - } - - impl UserEvents { - pub(crate) fn random(size: usize, num_events: usize) -> Self { - let mut events = OneOrMore::with_capacity(num_events); - for _ in 0..num_events { - events.push(UserEvent::random(size)); - } - - 
Self { - ltime: LamportTime::random(), - events, - } - } - } - - impl UserEventMessage { - fn random(size: usize) -> Self { - let rng = rand::thread_rng(); - let name = rng - .sample_iter(&Alphanumeric) - .take(size) - .collect::>(); - let name = String::from_utf8(name).unwrap(); - - let rng = rand::thread_rng(); - let payload = rng - .sample_iter(&Alphanumeric) - .take(size) - .collect::>(); - - Self { - ltime: LamportTime::random(), - name: name.into(), - payload: payload.into(), - cc: random(), - } - } - } - - #[test] - fn test_user_event_transform() { - futures::executor::block_on(async { - for i in 0..100 { - let event = UserEvent::random(i); - let mut buf = vec![0; event.encoded_len()]; - let encoded_len = event.encode(&mut buf).unwrap(); - assert_eq!(encoded_len, event.encoded_len()); - - let (decoded_len, decoded) = UserEvent::decode(&buf).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, event); - - let (decoded_len, decoded) = - UserEvent::decode_from_reader(&mut std::io::Cursor::new(&buf)).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, event); - - let (decoded_len, decoded) = - UserEvent::decode_from_async_reader(&mut futures::io::Cursor::new(&buf)) - .await - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, event); - } - }) - } - - #[test] - fn test_user_events_transform() { - futures::executor::block_on(async { - for i in 0..100 { - let events = UserEvents::random(i, i % 10); - let mut buf = vec![0; events.encoded_len()]; - let encoded_len = events.encode(&mut buf).unwrap(); - assert_eq!(encoded_len, events.encoded_len()); - - let (decoded_len, decoded) = UserEvents::decode(&buf).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, events); - - let (decoded_len, decoded) = - UserEvents::decode_from_reader(&mut std::io::Cursor::new(&buf)).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, events); - - let (decoded_len, decoded) = - 
UserEvents::decode_from_async_reader(&mut futures::io::Cursor::new(&buf)) - .await - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, events); - } - }) - } - - #[test] - fn test_user_event_message_transform() { - futures::executor::block_on(async { - for i in 0..100 { - let event = UserEventMessage::random(i); - let mut buf = vec![0; event.encoded_len()]; - let encoded_len = event.encode(&mut buf).unwrap(); - assert_eq!(encoded_len, event.encoded_len()); - - let (decoded_len, decoded) = UserEventMessage::decode(&buf).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, event); - - let (decoded_len, decoded) = - UserEventMessage::decode_from_reader(&mut std::io::Cursor::new(&buf)).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, event); - - let (decoded_len, decoded) = - UserEventMessage::decode_from_async_reader(&mut futures::io::Cursor::new(&buf)) - .await - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, event); - } - }) - } -} diff --git a/types/src/version.rs b/types/src/version.rs deleted file mode 100644 index 08f270b..0000000 --- a/types/src/version.rs +++ /dev/null @@ -1,152 +0,0 @@ -/// Unknown delegate version -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, thiserror::Error)] -#[error("V{0} is not a valid delegate version")] -pub struct UnknownDelegateVersion(u8); - -/// Delegate version -#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] -// #[cfg_attr( -// feature = "rkyv", -// derive(::rkyv::Serialize, ::rkyv::Deserialize, ::rkyv::Archive) -// )] -// #[cfg_attr(feature = "rkyv", archive(compare(PartialEq), check_bytes))] -// #[cfg_attr( -// feature = "rkyv", -// archive_attr( -// derive(Debug, Copy, Clone, Eq, PartialEq, Hash), -// repr(u8), -// non_exhaustive -// ) -// )] -#[non_exhaustive] -#[repr(u8)] -pub enum DelegateVersion { - /// Version 1 - #[default] - V1 = 1, 
-} - -impl core::fmt::Display for DelegateVersion { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - DelegateVersion::V1 => write!(f, "V1"), - } - } -} - -impl TryFrom for DelegateVersion { - type Error = UnknownDelegateVersion; - fn try_from(v: u8) -> Result { - match v { - 1 => Ok(DelegateVersion::V1), - _ => Err(UnknownDelegateVersion(v)), - } - } -} - -#[cfg(feature = "rkyv")] -const _: () = { - impl From for DelegateVersion { - fn from(value: ArchivedDelegateVersion) -> Self { - match value { - ArchivedDelegateVersion::V1 => Self::V1, - } - } - } - - impl From for ArchivedDelegateVersion { - fn from(value: DelegateVersion) -> Self { - match value { - DelegateVersion::V1 => Self::V1, - } - } - } -}; - -/// Unknown protocol version -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, thiserror::Error)] -#[error("V{0} is not a valid protocol version")] -pub struct UnknownProtocolVersion(u8); - -/// Protocol version -#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] -// #[cfg_attr( -// feature = "rkyv", -// derive(::rkyv::Serialize, ::rkyv::Deserialize, ::rkyv::Archive) -// )] -// #[cfg_attr(feature = "rkyv", archive(compare(PartialEq), check_bytes))] -// #[cfg_attr( -// feature = "rkyv", -// archive_attr( -// derive(Debug, Copy, Clone, Eq, PartialEq, Hash), -// repr(u8), -// non_exhaustive -// ) -// )] -#[non_exhaustive] -#[repr(u8)] -pub enum ProtocolVersion { - /// Version 1 - #[default] - V1 = 1, -} - -impl core::fmt::Display for ProtocolVersion { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::V1 => write!(f, "V1"), - } - } -} - -impl TryFrom for ProtocolVersion { - type Error = UnknownProtocolVersion; - fn try_from(v: u8) -> Result { - match v { - 1 => Ok(Self::V1), - _ => Err(UnknownProtocolVersion(v)), - } - } -} - -#[cfg(feature = "rkyv")] -const _: () = { - impl From for 
ProtocolVersion { - fn from(value: ArchivedProtocolVersion) -> Self { - match value { - ArchivedProtocolVersion::V1 => Self::V1, - } - } - } - - impl From for ArchivedProtocolVersion { - fn from(value: ProtocolVersion) -> Self { - match value { - ProtocolVersion::V1 => Self::V1, - } - } - } -}; - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_delegate_version() { - assert_eq!(DelegateVersion::V1 as u8, 1); - assert_eq!(DelegateVersion::V1.to_string(), "V1"); - assert_eq!(DelegateVersion::try_from(1), Ok(DelegateVersion::V1)); - assert_eq!(DelegateVersion::try_from(0), Err(UnknownDelegateVersion(0))); - } - - #[test] - fn test_protocol_version() { - assert_eq!(ProtocolVersion::V1 as u8, 1); - assert_eq!(ProtocolVersion::V1.to_string(), "V1"); - assert_eq!(ProtocolVersion::try_from(1), Ok(ProtocolVersion::V1)); - assert_eq!(ProtocolVersion::try_from(0), Err(UnknownProtocolVersion(0))); - } -} From 3311418dd1de3ae597671daf319b1718f005ed93 Mon Sep 17 00:00:00 2001 From: al8n Date: Tue, 25 Feb 2025 11:11:31 +0800 Subject: [PATCH 02/39] WIP --- serf-proto/src/arbitrary_impl.rs | 19 ++++++++++++++----- serf-proto/src/clock.rs | 25 +++++++++++++------------ serf-proto/src/filter.rs | 1 - serf-proto/src/leave.rs | 1 - serf-proto/src/lib.rs | 4 ++-- serf-proto/src/member.rs | 10 +++++----- serf-proto/src/push_pull.rs | 6 +++++- serf-proto/src/query.rs | 9 +++++---- serf-proto/src/tags.rs | 6 ++++-- serf-proto/src/user_event.rs | 3 +-- serf-proto/src/version.rs | 8 ++++++-- 11 files changed, 55 insertions(+), 37 deletions(-) diff --git a/serf-proto/src/arbitrary_impl.rs b/serf-proto/src/arbitrary_impl.rs index acb21ba..2dabfaa 100644 --- a/serf-proto/src/arbitrary_impl.rs +++ b/serf-proto/src/arbitrary_impl.rs @@ -1,4 +1,7 @@ -use std::{collections::{HashMap, HashSet}, hash::Hash}; +use std::{ + collections::{HashMap, HashSet}, + hash::Hash, +}; use super::Filter; use arbitrary::{Arbitrary, Unstructured}; @@ -13,7 +16,9 @@ where 
u.arbitrary::().map(Into::into) } -pub(super) fn arbitrary_indexmap<'a, K, V>(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result> +pub(super) fn arbitrary_indexmap<'a, K, V>( + u: &mut arbitrary::Unstructured<'a>, +) -> arbitrary::Result> where K: Arbitrary<'a> + Hash + Eq, V: Arbitrary<'a>, @@ -22,7 +27,9 @@ where Ok(IndexMap::from_iter(map)) } -pub(super) fn arbitrary_indexset<'a, K>(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result> +pub(super) fn arbitrary_indexset<'a, K>( + u: &mut arbitrary::Unstructured<'a>, +) -> arbitrary::Result> where K: Arbitrary<'a> + Hash + Eq, { @@ -30,7 +37,6 @@ where Ok(IndexSet::from_iter(map)) } - impl<'a, I> Arbitrary<'a> for Filter where I: Arbitrary<'a>, @@ -40,7 +46,10 @@ where Ok(if kind { Filter::Id(into::, TinyVec<_>>(u)?) } else { - Filter::Tag { tag: u.arbitrary()?, expr: u.arbitrary()? } + Filter::Tag { + tag: u.arbitrary()?, + expr: u.arbitrary()?, + } }) } } diff --git a/serf-proto/src/clock.rs b/serf-proto/src/clock.rs index b1b7143..84b8ffe 100644 --- a/serf-proto/src/clock.rs +++ b/serf-proto/src/clock.rs @@ -94,21 +94,22 @@ impl core::ops::Rem for LamportTime { } impl Data for LamportTime { - type Ref<'a> = Self; + type Ref<'a> = Self; - fn from_ref(val: Self::Ref<'_>) -> Result - where - Self: Sized { - Ok(val) - } + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + Ok(val) + } - fn encoded_len(&self) -> usize { - ::encoded_len(&self.0) - } + fn encoded_len(&self) -> usize { + ::encoded_len(&self.0) + } - fn encode(&self, buf: &mut [u8]) -> Result { - ::encode(&self.0, buf) - } + fn encode(&self, buf: &mut [u8]) -> Result { + ::encode(&self.0, buf) + } } impl<'a> DataRef<'a, LamportTime> for LamportTime { diff --git a/serf-proto/src/filter.rs b/serf-proto/src/filter.rs index b2c27cb..1715bfb 100644 --- a/serf-proto/src/filter.rs +++ b/serf-proto/src/filter.rs @@ -77,4 +77,3 @@ impl Filter { } } } - diff --git a/serf-proto/src/leave.rs b/serf-proto/src/leave.rs index 
ba3f2e4..65df594 100644 --- a/serf-proto/src/leave.rs +++ b/serf-proto/src/leave.rs @@ -30,4 +30,3 @@ pub struct LeaveMessage { )] prune: bool, } - diff --git a/serf-proto/src/lib.rs b/serf-proto/src/lib.rs index 80b0bdc..440848f 100644 --- a/serf-proto/src/lib.rs +++ b/serf-proto/src/lib.rs @@ -7,8 +7,8 @@ #![cfg_attr(docsrs, allow(unused_attributes))] pub use memberlist_proto::{ - DelegateVersion as MemberlistDelegateVersion, Node, NodeId, HostAddr, ParseDomainError, ParseHostAddrError, Domain, - ProtocolVersion as MemberlistProtocolVersion, + DelegateVersion as MemberlistDelegateVersion, Domain, HostAddr, Node, NodeId, ParseDomainError, + ParseHostAddrError, ProtocolVersion as MemberlistProtocolVersion, }; #[cfg(feature = "arbitrary")] diff --git a/serf-proto/src/member.rs b/serf-proto/src/member.rs index 23405ac..8b603f1 100644 --- a/serf-proto/src/member.rs +++ b/serf-proto/src/member.rs @@ -3,8 +3,8 @@ use std::sync::Arc; use memberlist_proto::CheapClone; use super::{ - DelegateVersion, MemberlistDelegateVersion, MemberlistProtocolVersion, Node, - ProtocolVersion, Tags, + DelegateVersion, MemberlistDelegateVersion, MemberlistProtocolVersion, Node, ProtocolVersion, + Tags, }; const MEMBER_STATUS_NONE: u8 = 0; @@ -13,9 +13,10 @@ const MEMBER_STATUS_LEAVING: u8 = 2; const MEMBER_STATUS_LEFT: u8 = 3; const MEMBER_STATUS_FAILED: u8 = 4; - /// The member status. -#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash, derive_more::IsVariant, derive_more::Display)] +#[derive( + Debug, Default, Copy, Clone, Eq, PartialEq, Hash, derive_more::IsVariant, derive_more::Display, +)] #[repr(u8)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] @@ -83,7 +84,6 @@ impl MemberStatus { } } - /// A single member of the Serf cluster. 
#[viewit::viewit(setters(prefix = "with"))] #[derive(Debug, PartialEq)] diff --git a/serf-proto/src/push_pull.rs b/serf-proto/src/push_pull.rs index 6d52f19..76823c9 100644 --- a/serf-proto/src/push_pull.rs +++ b/serf-proto/src/push_pull.rs @@ -1,7 +1,6 @@ use indexmap::{IndexMap, IndexSet}; use memberlist_proto::TinyVec; - use super::{LamportTime, UserEvents}; /// Used when doing a state exchange. This @@ -17,6 +16,10 @@ use super::{LamportTime, UserEvents}; )) )] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +#[cfg_attr( + feature = "arbitrary", + arbitrary(bound = "I: arbitrary::Arbitrary<'arbitrary> + core::cmp::Eq + core::hash::Hash") +)] pub struct PushPullMessage { /// Current node lamport time #[viewit( @@ -60,6 +63,7 @@ pub struct PushPullMessage { getter(const, style = "ref", attrs(doc = "Returns the recent events")), setter(attrs(doc = "Sets the recent events (Builder pattern)")) )] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::>, TinyVec>>))] events: TinyVec>, /// Lamport time for query clock #[viewit( diff --git a/serf-proto/src/query.rs b/serf-proto/src/query.rs index 349f71e..55a1067 100644 --- a/serf-proto/src/query.rs +++ b/serf-proto/src/query.rs @@ -1,11 +1,10 @@ use smol_str::SmolStr; - use std::time::Duration; use memberlist_proto::{Node, TinyVec, bytes::Bytes}; -use super::LamportTime; +use super::{Filter, LamportTime}; bitflags::bitflags! 
{ /// Flags for query message @@ -50,8 +49,8 @@ pub struct QueryMessage { getter(const, attrs(doc = "Returns the potential query filters")), setter(attrs(doc = "Sets the potential query filters (Builder pattern)")) )] - #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, TinyVec>))] - filters: TinyVec, + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::>, TinyVec>>))] + filters: TinyVec>, /// Used to provide various flags #[viewit( getter(const, style = "move", attrs(doc = "Returns the flags")), @@ -111,6 +110,7 @@ impl QueryMessage { #[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))] #[derive(Debug, Clone, Eq, PartialEq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub struct QueryResponseMessage { /// Event lamport time #[viewit( @@ -144,6 +144,7 @@ pub struct QueryResponseMessage { getter(const, style = "ref", attrs(doc = "Returns the payload")), setter(attrs(doc = "Sets the payload (Builder pattern)")) )] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, Bytes>))] payload: Bytes, } diff --git a/serf-proto/src/tags.rs b/serf-proto/src/tags.rs index 03822dd..dd18344 100644 --- a/serf-proto/src/tags.rs +++ b/serf-proto/src/tags.rs @@ -1,7 +1,6 @@ use indexmap::IndexMap; use smol_str::SmolStr; - /// Tags of a node #[derive( Debug, @@ -16,7 +15,10 @@ use smol_str::SmolStr; #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde", serde(transparent))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] -pub struct Tags(#[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::arbitrary_indexmap))] IndexMap); +pub struct Tags( + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::arbitrary_indexmap))] + IndexMap, +); impl IntoIterator for Tags { type Item = 
(SmolStr, SmolStr); diff --git a/serf-proto/src/user_event.rs b/serf-proto/src/user_event.rs index 37eb2cd..af2d281 100644 --- a/serf-proto/src/user_event.rs +++ b/serf-proto/src/user_event.rs @@ -1,4 +1,4 @@ -use memberlist_proto::{bytes::Bytes, CheapClone, OneOrMore}; +use memberlist_proto::{CheapClone, OneOrMore, bytes::Bytes}; use smol_str::SmolStr; use super::LamportTime; @@ -106,4 +106,3 @@ impl CheapClone for UserEventMessage { } } } - diff --git a/serf-proto/src/version.rs b/serf-proto/src/version.rs index b8f6123..dbbff00 100644 --- a/serf-proto/src/version.rs +++ b/serf-proto/src/version.rs @@ -1,7 +1,9 @@ use memberlist_proto::{Data, DataRef, DecodeError, EncodeError, WireType}; /// Delegate version -#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash, derive_more::IsVariant, derive_more::Display)] +#[derive( + Debug, Default, Copy, Clone, PartialEq, Eq, Hash, derive_more::IsVariant, derive_more::Display, +)] #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[non_exhaustive] @@ -34,7 +36,9 @@ impl From for u8 { } /// Protocol version -#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash, derive_more::IsVariant, derive_more::Display)] +#[derive( + Debug, Default, Copy, Clone, PartialEq, Eq, Hash, derive_more::IsVariant, derive_more::Display, +)] #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[non_exhaustive] From d39f02b67814f7885cac3183f3bda0fceb90b034 Mon Sep 17 00:00:00 2001 From: al8n Date: Tue, 25 Feb 2025 23:41:23 +0800 Subject: [PATCH 03/39] WIP --- Cargo.toml | 1 + serf-core/Cargo.toml | 2 +- serf-core/src/lib.rs | 254 ++++++------- serf-proto/Cargo.toml | 3 +- serf-proto/src/arbitrary_impl.rs | 88 ++++- serf-proto/src/filter.rs | 177 ++++++++- serf-proto/src/filter/id_filter.rs | 63 ++++ serf-proto/src/filter/tag_filter.rs | 225 ++++++++++++ 
serf-proto/src/join.rs | 131 ++++++- serf-proto/src/leave.rs | 172 ++++++++- serf-proto/src/lib.rs | 18 + serf-proto/src/member.rs | 291 ++++++++++++++- serf-proto/src/query.rs | 450 ++++++++++++++++++++--- serf-proto/src/query/response.rs | 310 ++++++++++++++++ serf-proto/src/tags.rs | 320 ++++++++++++++++ serf-proto/src/user_event.rs | 207 +++++++---- serf-proto/src/user_event/message.rs | 270 ++++++++++++++ serf-proto/src/user_event/user_events.rs | 182 +++++++++ serf-proto/src/version.rs | 2 - serf/src/lib.rs | 50 +-- 20 files changed, 2922 insertions(+), 294 deletions(-) create mode 100644 serf-proto/src/filter/id_filter.rs create mode 100644 serf-proto/src/filter/tag_filter.rs create mode 100644 serf-proto/src/query/response.rs create mode 100644 serf-proto/src/user_event/message.rs create mode 100644 serf-proto/src/user_event/user_events.rs diff --git a/Cargo.toml b/Cargo.toml index 28ba986..37e08ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ indexmap = "2" # memberlist = { version = "0.3", default-features = false } thiserror = { version = "2", default-features = false } viewit = "0.1.5" +regex = "1" smol_str = "0.3" smallvec = "1" rand = "0.9" diff --git a/serf-core/Cargo.toml b/serf-core/Cargo.toml index 584cca9..f6f668c 100644 --- a/serf-core/Cargo.toml +++ b/serf-core/Cargo.toml @@ -44,7 +44,7 @@ once_cell = "1" # remove this dependency when [feature(lazy_cell)] is stabilized parking_lot = { version = "0.12", features = ["send_guard"] } pin-project = "1" rand.workspace = true -regex = "1" +regex.workspace = true scopeguard = "1" smol_str.workspace = true smallvec.workspace = true diff --git a/serf-core/src/lib.rs b/serf-core/src/lib.rs index cc9cba2..c38d3f1 100644 --- a/serf-core/src/lib.rs +++ b/serf-core/src/lib.rs @@ -1,127 +1,127 @@ -#![doc = include_str!("../../README.md")] -#![doc(html_logo_url = "https://raw.githubusercontent.com/al8n/serf/main/art/logo_72x72.png")] -#![forbid(unsafe_code)] -#![deny(warnings, missing_docs)] 
-#![allow(clippy::type_complexity)] -#![cfg_attr(docsrs, feature(doc_cfg))] -#![cfg_attr(docsrs, allow(unused_attributes))] - -pub(crate) mod broadcast; - -mod coalesce; - -/// Coordinate. -pub mod coordinate; - -/// Events for [`Serf`] -pub mod event; - -/// Errors for `serf`. -pub mod error; - -/// Delegate traits and its implementations. -pub mod delegate; - -mod options; -pub use options::*; - -/// The types used in `serf`. -pub mod types; - -/// Secret key management. -#[cfg(feature = "encryption")] -#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] -pub mod key_manager; - -mod serf; -pub use serf::*; - -mod snapshot; -pub use snapshot::*; - -fn invalid_data_io_error(e: E) -> std::io::Error { - std::io::Error::new(std::io::ErrorKind::InvalidData, e) -} - -/// All unit test fns are exported in the `tests` module. -/// This module is used for users want to use other async runtime, -/// and want to use the test if memberlist also works with their runtime. -#[cfg(feature = "test")] -#[cfg_attr(docsrs, doc(cfg(feature = "test")))] -pub mod tests { - pub use memberlist_core::tests::{AnyError, next_socket_addr_v4, next_socket_addr_v6}; - pub use paste; - - pub use super::serf::base::tests::{serf::*, *}; - - /// Add `test` prefix to the predefined unit test fn with a given [`Runtime`](memberlist_core::agnostic_lite::RuntimeLite) - #[cfg(any(feature = "test", test))] - #[cfg_attr(docsrs, doc(cfg(any(feature = "test", test))))] - #[macro_export] - macro_rules! unit_tests { - ($runtime:ty => $run:ident($($fn:ident), +$(,)?)) => { - $( - ::serf_core::tests::paste::paste! { - #[test] - fn [< test_ $fn >] () { - $run($fn::<$runtime>()); - } - } - )* - }; - } - - /// Add `test` prefix to the predefined unit test fn with a given [`Runtime`](memberlist_core::agnostic_lite::RuntimeLite) - #[cfg(any(feature = "test", test))] - #[cfg_attr(docsrs, doc(cfg(any(feature = "test", test))))] - #[macro_export] - macro_rules! 
unit_tests_with_expr { - ($run:ident($( - $(#[$outer:meta])* - $fn:ident( $expr:expr ) - ), +$(,)?)) => { - $( - ::serf_core::tests::paste::paste! { - #[test] - $(#[$outer])* - fn [< test_ $fn >] () { - $run(async move { - $expr - }); - } - } - )* - }; - } - - /// Initialize the tracing for the unit tests. - pub fn initialize_tests_tracing() { - use std::sync::Once; - static TRACE: Once = Once::new(); - TRACE.call_once(|| { - let filter = std::env::var("RUSERF_TESTING_LOG") - .unwrap_or_else(|_| "serf_core=info,memberlist_core=debug".to_owned()); - memberlist_core::tracing::subscriber::set_global_default( - tracing_subscriber::fmt::fmt() - .without_time() - .with_line_number(true) - .with_env_filter(filter) - .with_file(false) - .with_target(true) - .with_ansi(true) - .finish(), - ) - .unwrap(); - }); - } - - /// Run the unit test with a given async runtime sequentially. - pub fn run(block_on: B, fut: F) - where - B: FnOnce(F) -> F::Output, - F: std::future::Future, - { - // initialize_tests_tracing(); - block_on(fut); - } -} +// #![doc = include_str!("../../README.md")] +// #![doc(html_logo_url = "https://raw.githubusercontent.com/al8n/serf/main/art/logo_72x72.png")] +// #![forbid(unsafe_code)] +// #![deny(warnings, missing_docs)] +// #![allow(clippy::type_complexity)] +// #![cfg_attr(docsrs, feature(doc_cfg))] +// #![cfg_attr(docsrs, allow(unused_attributes))] + +// pub(crate) mod broadcast; + +// mod coalesce; + +// /// Coordinate. +// pub mod coordinate; + +// /// Events for [`Serf`] +// pub mod event; + +// /// Errors for `serf`. +// pub mod error; + +// /// Delegate traits and its implementations. +// pub mod delegate; + +// mod options; +// pub use options::*; + +// /// The types used in `serf`. +// pub mod types; + +// /// Secret key management. 
+// #[cfg(feature = "encryption")] +// #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] +// pub mod key_manager; + +// mod serf; +// pub use serf::*; + +// mod snapshot; +// pub use snapshot::*; + +// fn invalid_data_io_error(e: E) -> std::io::Error { +// std::io::Error::new(std::io::ErrorKind::InvalidData, e) +// } + +// /// All unit test fns are exported in the `tests` module. +// /// This module is used for users want to use other async runtime, +// /// and want to use the test if memberlist also works with their runtime. +// #[cfg(feature = "test")] +// #[cfg_attr(docsrs, doc(cfg(feature = "test")))] +// pub mod tests { +// pub use memberlist_core::tests::{AnyError, next_socket_addr_v4, next_socket_addr_v6}; +// pub use paste; + +// pub use super::serf::base::tests::{serf::*, *}; + +// /// Add `test` prefix to the predefined unit test fn with a given [`Runtime`](memberlist_core::agnostic_lite::RuntimeLite) +// #[cfg(any(feature = "test", test))] +// #[cfg_attr(docsrs, doc(cfg(any(feature = "test", test))))] +// #[macro_export] +// macro_rules! unit_tests { +// ($runtime:ty => $run:ident($($fn:ident), +$(,)?)) => { +// $( +// ::serf_core::tests::paste::paste! { +// #[test] +// fn [< test_ $fn >] () { +// $run($fn::<$runtime>()); +// } +// } +// )* +// }; +// } + +// /// Add `test` prefix to the predefined unit test fn with a given [`Runtime`](memberlist_core::agnostic_lite::RuntimeLite) +// #[cfg(any(feature = "test", test))] +// #[cfg_attr(docsrs, doc(cfg(any(feature = "test", test))))] +// #[macro_export] +// macro_rules! unit_tests_with_expr { +// ($run:ident($( +// $(#[$outer:meta])* +// $fn:ident( $expr:expr ) +// ), +$(,)?)) => { +// $( +// ::serf_core::tests::paste::paste! { +// #[test] +// $(#[$outer])* +// fn [< test_ $fn >] () { +// $run(async move { +// $expr +// }); +// } +// } +// )* +// }; +// } + +// /// Initialize the tracing for the unit tests. 
+// pub fn initialize_tests_tracing() { +// use std::sync::Once; +// static TRACE: Once = Once::new(); +// TRACE.call_once(|| { +// let filter = std::env::var("RUSERF_TESTING_LOG") +// .unwrap_or_else(|_| "serf_core=info,memberlist_core=debug".to_owned()); +// memberlist_core::tracing::subscriber::set_global_default( +// tracing_subscriber::fmt::fmt() +// .without_time() +// .with_line_number(true) +// .with_env_filter(filter) +// .with_file(false) +// .with_target(true) +// .with_ansi(true) +// .finish(), +// ) +// .unwrap(); +// }); +// } + +// /// Run the unit test with a given async runtime sequentially. +// pub fn run(block_on: B, fut: F) +// where +// B: FnOnce(F) -> F::Output, +// F: std::future::Future, +// { +// // initialize_tests_tracing(); +// block_on(fut); +// } +// } diff --git a/serf-proto/Cargo.toml b/serf-proto/Cargo.toml index 61fc0de..cec8ff9 100644 --- a/serf-proto/Cargo.toml +++ b/serf-proto/Cargo.toml @@ -20,10 +20,11 @@ quickcheck = ["dep:quickcheck", "memberlist-proto/quickcheck"] bitflags = "2" byteorder.workspace = true bytemuck = { version = "1", features = ["derive"] } -derive_more = { workspace = true, features = ["is_variant", "display"] } +derive_more = { workspace = true, features = ["is_variant", "display", "unwrap", "try_unwrap"] } futures = { workspace = true, optional = true, features = ["alloc"] } indexmap.workspace = true memberlist-proto.workspace = true +regex.workspace = true smol_str.workspace = true thiserror.workspace = true viewit.workspace = true diff --git a/serf-proto/src/arbitrary_impl.rs b/serf-proto/src/arbitrary_impl.rs index 2dabfaa..d50f558 100644 --- a/serf-proto/src/arbitrary_impl.rs +++ b/serf-proto/src/arbitrary_impl.rs @@ -3,6 +3,8 @@ use std::{ hash::Hash, }; +use crate::TagFilter; + use super::Filter; use arbitrary::{Arbitrary, Unstructured}; use indexmap::{IndexMap, IndexSet}; @@ -46,10 +48,70 @@ where Ok(if kind { Filter::Id(into::, TinyVec<_>>(u)?) 
} else { - Filter::Tag { - tag: u.arbitrary()?, - expr: u.arbitrary()?, - } + Filter::Tag( + TagFilter::new() + .with_tag(u.arbitrary()?) + .maybe_expr(if u.arbitrary()? { + let complexity = u.int_in_range(1..=5)?; + let mut patterns = Vec::new(); + + // Basic character classes and quantifiers + let character_classes = vec![ + r"\d", + r"\w", + r"\s", + r"[a-z]", + r"[A-Z]", + r"[0-9]", + r"[a-zA-Z]", + r"[a-zA-Z0-9]", + r".", + ]; + + let quantifiers = vec!["", "*", "+", "?", "{1,3}", "{2,5}"]; + + // Add more complex patterns for higher complexity + let mut extended_classes = character_classes.clone(); + if complexity > 1 { + extended_classes.extend(vec![r"[^a-z]", r"[^0-9]", r"\D", r"\W", r"\S"]); + } + + if complexity > 2 { + // Add a group with random content + let char_class = u.choose(&extended_classes)?; + let quantifier = u.choose(&quantifiers)?; + patterns.push(format!("({}{})", char_class, quantifier)); + } + + // Generate random pattern parts + for _ in 0..complexity { + let char_class = u.choose(&extended_classes)?; + let quantifier = u.choose(&quantifiers)?; + patterns.push(format!("{}{}", char_class, quantifier)); + } + + // Maybe add anchors for higher complexity + if complexity > 2 && u.ratio(7, 10)? { + if u.arbitrary()? { + patterns.insert(0, "^".to_string()); + } + if u.arbitrary()? { + patterns.push("$".to_string()); + } + } + + // Add alternation for even higher complexity + if complexity > 3 && u.ratio(6, 10)? 
{ + let char_class = u.choose(&extended_classes)?; + let quantifier = u.choose(&quantifiers)?; + patterns.push(format!("|{}{}", char_class, quantifier)); + } + + Some(patterns.join("").try_into().unwrap()) + } else { + None + }), + ) }) } } @@ -64,6 +126,24 @@ impl<'a> Arbitrary<'a> for super::QueryFlag { } } +impl<'a> Arbitrary<'a> for super::ProtocolVersion { + fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { + u.arbitrary::().map(Into::into) + } +} + +impl<'a> Arbitrary<'a> for super::DelegateVersion { + fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { + u.arbitrary::().map(Into::into) + } +} + +impl<'a> Arbitrary<'a> for super::MemberStatus { + fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { + u.arbitrary::().map(Into::into) + } +} + #[cfg(feature = "encryption")] impl<'a, I> Arbitrary<'a> for super::KeyResponse where diff --git a/serf-proto/src/filter.rs b/serf-proto/src/filter.rs index 1715bfb..1da01b3 100644 --- a/serf-proto/src/filter.rs +++ b/serf-proto/src/filter.rs @@ -1,5 +1,13 @@ -use memberlist_proto::TinyVec; -use smol_str::SmolStr; +use memberlist_proto::{ + Data, DataRef, DecodeError, EncodeError, TinyVec, WireType, + utils::{merge, split}, +}; + +pub use id_filter::*; +pub use tag_filter::*; + +mod id_filter; +mod tag_filter; /// The type of filter #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, derive_more::IsVariant, derive_more::Display)] @@ -52,19 +60,14 @@ impl From for u8 { /// Used with a queryFilter to specify the type of /// filter we are sending -#[derive(Debug, Clone, Eq, PartialEq, derive_more::IsVariant)] +#[derive(Debug, Clone, PartialEq, Eq, derive_more::IsVariant)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[non_exhaustive] pub enum Filter { /// Filter by node ids Id(TinyVec), /// Filter by tag - Tag { - /// The tag to filter by - tag: SmolStr, - /// The expression to filter by - expr: SmolStr, - }, + Tag(TagFilter), } impl Filter { @@ -77,3 +80,159 @@ 
impl Filter { } } } + +const FILTER_ID_TAG: u8 = 1; +const FILTER_TAG_TAG: u8 = 2; + +/// The reference type to [`Filter`] +pub enum FilterRef<'a, I> { + /// Filter by node ids + Id(IdDecoder<'a, I>), + /// Filter by tag + Tag(TagFilterRef<'a>), +} + +impl Clone for FilterRef<'_, I> { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for FilterRef<'_, I> {} + +impl core::fmt::Debug for FilterRef<'_, I> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Id(id) => f.debug_tuple("FilterRef::Id").field(id).finish(), + Self::Tag(t) => f.debug_tuple("FilterRef::Tag").field(t).finish(), + } + } +} + +impl<'a, I> DataRef<'a, Filter> for FilterRef<'a, I> +where + I: Data, +{ + fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> + where + Self: Sized, + { + let buf_len = buf.len(); + if buf_len < 1 { + return Err(DecodeError::buffer_underflow()); + } + + let mut offset = 0; + + match buf[0] { + val if val == Filter::::id_byte() => { + offset += 1; + Ok((offset, Self::Id(IdDecoder::new(&buf[offset..])))) + } + val if val == Filter::::tag_byte() => { + offset += 1; + let (read, tag) = + >::decode_length_delimited(&buf[offset..])?; + offset += read; + Ok((offset, Self::Tag(tag))) + } + b => { + let (wire_type, tag) = split(b); + WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + + Err(DecodeError::unknown_tag("Filter", tag)) + } + } + } +} + +impl Filter +where + I: Data, +{ + const fn id_byte() -> u8 { + merge(I::WIRE_TYPE, FILTER_ID_TAG) + } + + const fn tag_byte() -> u8 { + merge(WireType::LengthDelimited, FILTER_TAG_TAG) + } +} + +impl Data for Filter +where + I: Data, +{ + type Ref<'a> = FilterRef<'a, I>; + + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + match val { + FilterRef::Id(decoder) => decoder + .map(|res| res.and_then(I::from_ref)) + .collect::>() + .map(Self::Id), + FilterRef::Tag(tag) => TagFilter::from_ref(tag).map(Self::Tag), + } + } + + fn 
encoded_len(&self) -> usize { + 1usize + + match self { + Filter::Id(ids) => ids + .iter() + .map(|id| id.encoded_len_with_length_delimited()) + .sum::(), + Filter::Tag(tag) => 1 + tag.encoded_len_with_length_delimited(), + } + } + + fn encode(&self, buf: &mut [u8]) -> Result { + let buf_len = buf.len(); + if buf_len < 1 { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + + let mut offset = 0; + + match self { + Filter::Id(ids) => { + buf[offset] = Self::id_byte(); + offset += 1; + + ids + .iter() + .try_fold(&mut offset, |offset, id| { + *offset += id.encode_length_delimited(&mut buf[*offset..])?; + + Ok(offset) + }) + .map_err(|e: EncodeError| e.update(self.encoded_len(), buf_len))?; + + Ok(offset) + } + Filter::Tag(tag) => { + buf[offset] = Self::tag_byte(); + offset += 1; + + if offset > buf_len { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + + offset += tag + .encode_length_delimited(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + + Ok(offset) + } + } + } +} diff --git a/serf-proto/src/filter/id_filter.rs b/serf-proto/src/filter/id_filter.rs new file mode 100644 index 0000000..0de00a2 --- /dev/null +++ b/serf-proto/src/filter/id_filter.rs @@ -0,0 +1,63 @@ +use memberlist_proto::{Data, DataRef, DecodeError}; + +/// The decoder for ids +pub struct IdDecoder<'a, I> { + src: &'a [u8], + len: usize, + offset: usize, + has_err: bool, + _phantom: std::marker::PhantomData, +} + +impl Clone for IdDecoder<'_, I> { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for IdDecoder<'_, I> {} + +impl core::fmt::Debug for IdDecoder<'_, I> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("IdDecoder") + .field("src", &self.src) + .field("offset", &self.offset) + .finish() + } +} + +impl<'a, I> IdDecoder<'a, I> { + pub(super) const fn new(src: &'a [u8]) -> Self { + Self { + src, + offset: 0, + len: src.len(), + 
has_err: false, + _phantom: std::marker::PhantomData, + } + } +} + +impl<'a, I> Iterator for IdDecoder<'a, I> +where + I: Data, +{ + type Item = Result, DecodeError>; + + fn next(&mut self) -> Option { + if self.has_err || self.offset >= self.len { + return None; + } + + Some( + as DataRef<'_, I>>::decode_length_delimited(&self.src[self.offset..]) + .inspect_err(|_| { + self.has_err = true; + }) + .map(|(read, value)| { + self.offset += read; + value + }), + ) + } +} diff --git a/serf-proto/src/filter/tag_filter.rs b/serf-proto/src/filter/tag_filter.rs new file mode 100644 index 0000000..fb241f9 --- /dev/null +++ b/serf-proto/src/filter/tag_filter.rs @@ -0,0 +1,225 @@ +use memberlist_proto::{ + Data, DataRef, DecodeError, EncodeError, WireType, + utils::{merge, skip, split}, +}; +use regex::Regex; +use smol_str::SmolStr; + +const TAG_TAG: u8 = 1; +const EXPR_TAG: u8 = 2; +const TAG_BYTE: u8 = merge(WireType::LengthDelimited, TAG_TAG); +const EXPR_BYTE: u8 = merge(WireType::LengthDelimited, EXPR_TAG); + +#[viewit::viewit(vis_all = "", getters(vis_all = "pub"), setters(skip))] +/// The reference type of the [`TagFilter`] type +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct TagFilterRef<'a> { + #[viewit(getter(const, attrs(doc = "Returns the tag")))] + tag: &'a str, + #[viewit(getter(const, attrs(doc = "Returns the expression")))] + expr: Option<&'a str>, +} + +impl<'a> DataRef<'a, TagFilter> for TagFilterRef<'a> { + fn decode(src: &'a [u8]) -> Result<(usize, Self), memberlist_proto::DecodeError> + where + Self: Sized, + { + let mut offset = 0; + let buf_len = src.len(); + let mut tag = None; + let mut expr = None; + + while offset < buf_len { + match src[offset] { + TAG_BYTE => { + if tag.is_some() { + return Err(DecodeError::duplicate_field("TagFilter", "tag", TAG_TAG)); + } + offset += 1; + + let (read, value) = <&str as DataRef<'_, SmolStr>>::decode(&src[offset..])?; + offset += read; + tag = Some(value); + } + EXPR_BYTE => { + if expr.is_some() { 
+ return Err(DecodeError::duplicate_field("TagFilter", "expr", EXPR_TAG)); + } + offset += 1; + + let (read, value) = <&str as DataRef<'_, SmolStr>>::decode(&src[offset..])?; + offset += read; + expr = Some(value); + } + other => { + offset += 1; + + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &src[offset..])?; + } + } + } + + Ok(( + offset, + Self { + tag: tag.unwrap_or(""), + expr: expr.and_then(|expr| if expr.is_empty() { None } else { Some(expr) }), + }, + )) + } +} + +/// The tag filter +#[viewit::viewit( + vis_all = "", + getters(vis_all = "pub", style = "ref"), + setters(vis_all = "pub", prefix = "with") +)] +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct TagFilter { + #[viewit( + getter(const, attrs(doc = "Returns the tag")), + setter(attrs(doc = "Sets the tag (Builder pattern)")) + )] + tag: SmolStr, + #[cfg_attr(feature = "serde", serde(with = "serde_regex"))] + #[viewit( + getter( + const, + attrs(doc = "Returns the expression"), + result(converter(fn = "Option::as_ref"), type = "Option<&Regex>"), + ), + setter( + rename = "maybe_expr", + attrs(doc = "Sets the expression (Builder pattern)") + ) + )] + expr: Option, +} + +impl Default for TagFilter { + fn default() -> Self { + Self::new() + } +} + +impl TagFilter { + /// Creates a new tag filter + #[inline] + pub const fn new() -> Self { + Self { + tag: SmolStr::new_inline(""), + expr: None, + } + } +} + +impl PartialEq for TagFilter { + fn eq(&self, other: &Self) -> bool { + self.tag == other.tag + && self.expr.as_ref().map(|re| re.as_str()) == other.expr.as_ref().map(|re| re.as_str()) + } +} + +impl Eq for TagFilter {} + +impl Data for TagFilter { + type Ref<'a> = TagFilterRef<'a>; + + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + Ok(Self { + tag: SmolStr::from(val.tag), + expr: val + .expr + .map(|expr| 
Regex::new(expr).map_err(|e| DecodeError::custom(e.to_string()))) + .transpose()?, + }) + } + + fn encoded_len(&self) -> usize { + 1 + self.tag.encoded_len_with_length_delimited() + + match self.expr.as_ref() { + Some(re) => { + let re = re.as_str(); + let len = re.len(); + 1 + (len as u32).encoded_len() + len + } + None => 0, + } + } + + fn encode(&self, buf: &mut [u8]) -> Result { + let buf_len = buf.len(); + let mut offset = 0; + + if buf_len <= offset { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + + buf[offset] = TAG_BYTE; + offset += 1; + offset += self + .tag + .encode_length_delimited(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + + if let Some(re) = self.expr.as_ref() { + let re = re.as_str(); + let len = re.len(); + if buf_len < offset + 1 + (len as u32).encoded_len() + len { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + + buf[offset] = EXPR_BYTE; + offset += 1; + offset += (len as u32).encode(&mut buf[offset..])?; + buf[offset..offset + len].copy_from_slice(re.as_bytes()); + offset += len; + } + + #[cfg(debug_assertions)] + super::super::debug_assert_write_eq(offset, self.encoded_len()); + + Ok(offset) + } +} + +#[cfg(feature = "serde")] +mod serde_regex { + use regex::Regex; + use serde::{de, ser}; + + pub fn serialize(value: &Option, serializer: S) -> Result + where + S: ser::Serializer, + { + match value { + Some(re) => serializer.serialize_str(re.as_str()), + None => serializer.serialize_none(), + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: de::Deserializer<'de>, + { + let s = as de::Deserialize<'_>>::deserialize(deserializer)?; + match s { + Some(s) => s.try_into().map(Some).map_err(de::Error::custom), + None => Ok(None), + } + } +} diff --git a/serf-proto/src/join.rs b/serf-proto/src/join.rs index 9cce02d..906ded5 100644 --- a/serf-proto/src/join.rs +++ b/serf-proto/src/join.rs @@ 
-1,9 +1,18 @@ +use memberlist_proto::{ + Data, DataRef, DecodeError, EncodeError, WireType, + utils::{merge, skip, split}, +}; + use super::LamportTime; +const LTIME_TAG: u8 = 1; +const LTIME_BYTE: u8 = merge(WireType::Varint, LTIME_TAG); +const ID_TAG: u8 = 2; + /// The message broadcasted after we join to /// associated the node with a lamport clock #[viewit::viewit(setters(prefix = "with"))] -#[derive(Debug, Clone, Eq, PartialEq, Hash)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub struct JoinMessage { @@ -43,4 +52,124 @@ impl JoinMessage { self.id = id; self } + + const fn id_byte() -> u8 + where + I: Data, + { + merge(I::WIRE_TYPE, ID_TAG) + } +} + +impl<'a, I> DataRef<'a, JoinMessage> for JoinMessage> +where + I: Data, +{ + fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> + where + Self: Sized, + { + let mut offset = 0; + let mut ltime = None; + let mut id = None; + + while offset < buf.len() { + match buf[offset] { + LTIME_BYTE => { + if ltime.is_some() { + return Err(DecodeError::duplicate_field( + "JoinMessage", + "ltime", + LTIME_TAG, + )); + } + offset += 1; + + let (read, value) = >::decode(&buf[offset..])?; + offset += read; + ltime = Some(value); + } + ID_TAG => { + if id.is_some() { + return Err(DecodeError::duplicate_field("JoinMessage", "id", ID_TAG)); + } + offset += 1; + + let (read, value) = + as DataRef<'_, I>>::decode_length_delimited(&buf[offset..])?; + offset += read; + id = Some(value); + } + other => { + offset += 1; + + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &buf[offset..])?; + } + } + } + + Ok(( + offset, + Self { + ltime: ltime.ok_or_else(|| DecodeError::missing_field("JoinMessage", "ltime"))?, + id: id.ok_or_else(|| DecodeError::missing_field("JoinMessage", 
"id"))?, + }, + )) + } +} + +impl Data for JoinMessage +where + I: Data, +{ + type Ref<'a> = JoinMessage>; + + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + I::from_ref(val.id).map(|id| Self { + ltime: val.ltime, + id, + }) + } + + fn encoded_len(&self) -> usize { + 1 + self.ltime.encoded_len() + 1 + self.id.encoded_len_with_length_delimited() + } + + fn encode(&self, buf: &mut [u8]) -> Result { + let buf_len = buf.len(); + let mut offset = 0; + + if buf_len < 1 { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + + buf[offset] = LTIME_BYTE; + offset += 1; + offset += self.ltime.encode(buf)?; + + if buf_len <= offset { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + + buf[offset] = Self::id_byte(); + offset += 1; + + offset += self.id.encode_length_delimited(&mut buf[offset..])?; + + #[cfg(debug_assertions)] + super::debug_assert_write_eq(offset, self.encoded_len()); + + Ok(offset) + } } diff --git a/serf-proto/src/leave.rs b/serf-proto/src/leave.rs index 65df594..8d0f0b1 100644 --- a/serf-proto/src/leave.rs +++ b/serf-proto/src/leave.rs @@ -1,9 +1,21 @@ +use memberlist_proto::{ + Data, DataRef, DecodeError, EncodeError, WireType, + utils::{merge, skip, split}, +}; + use super::LamportTime; +const LTIME_TAG: u8 = 1; +const PRUNE_TAG: u8 = 2; +const ID_TAG: u8 = 3; + +const LTIME_BYTE: u8 = merge(WireType::Varint, LTIME_TAG); +const PRUNE_BYTE: u8 = merge(WireType::Byte, PRUNE_TAG); + /// The message broadcasted to signal the intentional to /// leave. 
#[viewit::viewit(setters(prefix = "with"))] -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Copy, Eq, PartialEq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub struct LeaveMessage { @@ -30,3 +42,161 @@ pub struct LeaveMessage { )] prune: bool, } + +impl LeaveMessage { + const fn id_byte() -> u8 + where + I: Data, + { + merge(I::WIRE_TYPE, ID_TAG) + } +} + +impl<'a, I> DataRef<'a, LeaveMessage> for LeaveMessage> +where + I: Data, +{ + fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> + where + Self: Sized, + { + let mut offset = 0; + let buf_len = buf.len(); + + let mut ltime = None; + let mut id = None; + let mut prune = None; + + while offset < buf_len { + match buf[offset] { + LTIME_BYTE => { + if ltime.is_some() { + return Err(DecodeError::duplicate_field( + "LeaveMessage", + "ltime", + LTIME_TAG, + )); + } + offset += 1; + + let (read, value) = >::decode(&buf[offset..])?; + offset += read; + ltime = Some(value); + } + PRUNE_BYTE => { + if prune.is_some() { + return Err(DecodeError::duplicate_field( + "LeaveMessage", + "prune", + PRUNE_TAG, + )); + } + offset += 1; + + let (read, value) = >::decode(&buf[offset..])?; + offset += read; + prune = Some(value); + } + val if val == LeaveMessage::::id_byte() => { + offset += 1; + let (read, id_ref) = I::Ref::decode_length_delimited(&buf[offset..])?; + offset += read; + id = Some(id_ref); + } + other => { + offset += 1; + + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &buf[offset..])?; + } + } + } + + Ok(( + offset, + Self { + ltime: ltime.ok_or_else(|| DecodeError::missing_field("LeaveMessage", "ltime"))?, + id: id.ok_or_else(|| DecodeError::missing_field("LeaveMessage", "id"))?, + prune: prune.unwrap_or_default(), + }, + )) + } +} + +impl Data for LeaveMessage +where + I: Data, +{ + 
type Ref<'a> = LeaveMessage>; + + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + I::from_ref(val.id).map(|id| Self { + ltime: val.ltime, + id, + prune: val.prune, + }) + } + + fn encoded_len(&self) -> usize { + 1 + self.ltime.encoded_len() + + if self.prune { 1 + 1 } else { 0 } + + 1 + + self.id.encoded_len_with_length_delimited() + } + + fn encode(&self, buf: &mut [u8]) -> Result { + let mut offset = 0; + let buf_len = buf.len(); + + if offset >= buf_len { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + + buf[offset] = LTIME_BYTE; + offset += 1; + offset += self + .ltime + .encode(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + + if self.prune { + if offset >= buf_len { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + + buf[offset] = PRUNE_BYTE; + offset += 1; + offset += ::encode(&true, &mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + } + + if offset >= buf_len { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + + buf[offset] = Self::id_byte(); + offset += 1; + offset += self + .id + .encode_length_delimited(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + + #[cfg(debug_assertions)] + super::debug_assert_write_eq(offset, self.encoded_len()); + + Ok(offset) + } +} diff --git a/serf-proto/src/lib.rs b/serf-proto/src/lib.rs index 440848f..8321c73 100644 --- a/serf-proto/src/lib.rs +++ b/serf-proto/src/lib.rs @@ -52,3 +52,21 @@ mod key; #[cfg(feature = "encryption")] #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] pub use key::*; + +#[cfg(debug_assertions)] +#[inline] +fn debug_assert_write_eq(actual: usize, expected: usize) { + debug_assert_eq!( + actual, expected, + "expect writting {expected} bytes, but actual write {actual} bytes" + ); +} + +// #[cfg(debug_assertions)] +// #[inline] +// fn debug_assert_read_eq(actual: 
usize, expected: usize) { +// debug_assert_eq!( +// actual, expected, +// "expect reading {expected} bytes, but actual read {actual} bytes" +// ); +// } diff --git a/serf-proto/src/member.rs b/serf-proto/src/member.rs index 8b603f1..a3d12f4 100644 --- a/serf-proto/src/member.rs +++ b/serf-proto/src/member.rs @@ -1,6 +1,11 @@ use std::sync::Arc; -use memberlist_proto::CheapClone; +use memberlist_proto::{ + CheapClone, Data, DataRef, EncodeError, WireType, + utils::{merge, skip, split}, +}; + +use crate::TagsRef; use super::{ DelegateVersion, MemberlistDelegateVersion, MemberlistProtocolVersion, Node, ProtocolVersion, @@ -19,7 +24,6 @@ const MEMBER_STATUS_FAILED: u8 = 4; )] #[repr(u8)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[non_exhaustive] pub enum MemberStatus { /// None status @@ -185,3 +189,286 @@ impl CheapClone for Member { } } } + +const NODE_TAG: u8 = 1; +const TAGS_TAG: u8 = 2; +const STATUS_TAG: u8 = 3; +const MEMBERLIST_PROTOCOL_VERSION_TAG: u8 = 4; +const MEMBERLIST_DELEGATE_VERSION_TAG: u8 = 5; +const PROTOCOL_VERSION_TAG: u8 = 6; +const DELEGATE_VERSION_TAG: u8 = 7; + +const NODE_BYTE: u8 = merge(WireType::LengthDelimited, NODE_TAG); +const TAGS_BYTE: u8 = merge(WireType::LengthDelimited, TAGS_TAG); +const STATUS_BYTE: u8 = merge(WireType::Byte, STATUS_TAG); +const MEMBERLIST_PROTOCOL_VERSION_BYTE: u8 = merge(WireType::Byte, MEMBERLIST_PROTOCOL_VERSION_TAG); +const MEMBERLIST_DELEGATE_VERSION_BYTE: u8 = merge(WireType::Byte, MEMBERLIST_DELEGATE_VERSION_TAG); +const PROTOCOL_VERSION_BYTE: u8 = merge(WireType::Byte, PROTOCOL_VERSION_TAG); +const DELEGATE_VERSION_BYTE: u8 = merge(WireType::Byte, DELEGATE_VERSION_TAG); + +/// A reference type to [`Member`] +#[viewit::viewit(vis_all = "", getters(vis_all = "pub"), setters(skip))] +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct MemberRef<'a, I, A> { + /// The node + #[viewit(getter(const, style = 
"ref", attrs(doc = "Returns the node")))] + node: Node, + /// The tags + #[viewit(getter(const, style = "ref", attrs(doc = "Returns the tags")))] + tags: TagsRef<'a>, + /// The status + #[viewit(getter(const, style = "ref", attrs(doc = "Returns the status")))] + status: MemberStatus, + /// The memberlist protocol version + #[viewit(getter(const, attrs(doc = "Returns the memberlist protocol version")))] + memberlist_protocol_version: MemberlistProtocolVersion, + /// The memberlist delegate version + #[viewit(getter(const, attrs(doc = "Returns the memberlist delegate version")))] + memberlist_delegate_version: MemberlistDelegateVersion, + /// The serf protocol version + #[viewit(getter(const, attrs(doc = "Returns the serf protocol version")))] + protocol_version: ProtocolVersion, + /// The serf delegate version + #[viewit(getter(const, attrs(doc = "Returns the serf delegate version")))] + delegate_version: DelegateVersion, +} + +impl<'a, I, A> DataRef<'a, Member> for MemberRef<'a, I::Ref<'a>, A::Ref<'a>> +where + I: Data, + A: Data, +{ + fn decode(buf: &'a [u8]) -> Result<(usize, Self), memberlist_proto::DecodeError> + where + Self: Sized, + { + let mut offset = 0; + let buf_len = buf.len(); + + let mut node = None; + let mut tags = None; + let mut status = None; + let mut memberlist_protocol_version = None; + let mut memberlist_delegate_version = None; + let mut protocol_version = None; + let mut delegate_version = None; + + while offset < buf_len { + match buf[offset] { + NODE_BYTE => { + if node.is_some() { + return Err(memberlist_proto::DecodeError::duplicate_field( + "Member", "node", NODE_TAG, + )); + } + offset += 1; + let (size, val) = + , A::Ref<'_>> as DataRef<'_, Node>>::decode_length_delimited( + &buf[offset..], + )?; + node = Some(val); + offset += size; + } + TAGS_BYTE => { + if tags.is_some() { + return Err(memberlist_proto::DecodeError::duplicate_field( + "Member", "tags", TAGS_TAG, + )); + } + offset += 1; + let (size, val) = + as DataRef<'_, 
Tags>>::decode_length_delimited(&buf[offset..])?; + tags = Some(val); + offset += size; + } + STATUS_BYTE => { + if status.is_some() { + return Err(memberlist_proto::DecodeError::duplicate_field( + "Member", "status", STATUS_TAG, + )); + } + offset += 1; + status = Some(buf[offset].into()); + offset += 1; + } + MEMBERLIST_PROTOCOL_VERSION_BYTE => { + if memberlist_protocol_version.is_some() { + return Err(memberlist_proto::DecodeError::duplicate_field( + "Member", + "memberlist_protocol_version", + MEMBERLIST_PROTOCOL_VERSION_TAG, + )); + } + offset += 1; + memberlist_protocol_version = Some(buf[offset].into()); + offset += 1; + } + MEMBERLIST_DELEGATE_VERSION_BYTE => { + if memberlist_delegate_version.is_some() { + return Err(memberlist_proto::DecodeError::duplicate_field( + "Member", + "memberlist_delegate_version", + MEMBERLIST_DELEGATE_VERSION_TAG, + )); + } + offset += 1; + memberlist_delegate_version = Some(buf[offset].into()); + offset += 1; + } + PROTOCOL_VERSION_BYTE => { + if protocol_version.is_some() { + return Err(memberlist_proto::DecodeError::duplicate_field( + "Member", + "protocol_version", + PROTOCOL_VERSION_TAG, + )); + } + offset += 1; + protocol_version = Some(buf[offset].into()); + offset += 1; + } + DELEGATE_VERSION_BYTE => { + if delegate_version.is_some() { + return Err(memberlist_proto::DecodeError::duplicate_field( + "Member", + "delegate_version", + DELEGATE_VERSION_TAG, + )); + } + offset += 1; + delegate_version = Some(buf[offset].into()); + offset += 1; + } + other => { + offset += 1; + + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type) + .map_err(memberlist_proto::DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &buf[offset..])?; + } + } + } + + Ok(( + offset, + Self { + node: node.ok_or_else(|| memberlist_proto::DecodeError::missing_field("Member", "node"))?, + tags: tags.ok_or_else(|| memberlist_proto::DecodeError::missing_field("Member", "tags"))?, + status: status + .ok_or_else(|| 
memberlist_proto::DecodeError::missing_field("Member", "status"))?, + memberlist_protocol_version: memberlist_protocol_version.ok_or_else(|| { + memberlist_proto::DecodeError::missing_field("Member", "memberlist_protocol_version") + })?, + memberlist_delegate_version: memberlist_delegate_version.ok_or_else(|| { + memberlist_proto::DecodeError::missing_field("Member", "memberlist_delegate_version") + })?, + protocol_version: protocol_version.ok_or_else(|| { + memberlist_proto::DecodeError::missing_field("Member", "protocol_version") + })?, + delegate_version: delegate_version.ok_or_else(|| { + memberlist_proto::DecodeError::missing_field("Member", "delegate_version") + })?, + }, + )) + } +} + +impl Data for Member +where + I: Data, + A: Data, +{ + type Ref<'a> = MemberRef<'a, I::Ref<'a>, A::Ref<'a>>; + + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + Ok(Self { + node: Node::from_ref(val.node)?, + tags: Tags::from_ref(val.tags)?.into(), + status: val.status, + memberlist_protocol_version: val.memberlist_protocol_version, + memberlist_delegate_version: val.memberlist_delegate_version, + protocol_version: val.protocol_version, + delegate_version: val.delegate_version, + }) + } + + fn encoded_len(&self) -> usize { + let mut len = 0; + len += 1 + self.node.encoded_len_with_length_delimited(); + len += 1 + self.tags.encoded_len_with_length_delimited(); + len += 1 + 1; // status + len += 1 + 1; // memberlist_protocol_version + len += 1 + 1; // memberlist_delegate_version + len += 1 + 1; // protocol_version + len += 1 + 1; // delegate_version + len + } + + fn encode(&self, buf: &mut [u8]) -> Result { + macro_rules! 
bail { + ($this:ident($offset:expr, $len:ident)) => { + if $offset >= $len { + return Err(EncodeError::insufficient_buffer(self.encoded_len(), $len)); + } + }; + } + + let buf_len = buf.len(); + let mut offset = 0; + bail!(self(offset, buf_len)); + + buf[offset] = NODE_BYTE; + offset += 1; + offset += self.node.encode_length_delimited(&mut buf[offset..])?; + + bail!(self(offset, buf_len)); + buf[offset] = TAGS_BYTE; + offset += 1; + offset += self.tags.encode_length_delimited(&mut buf[offset..])?; + + bail!(self(offset, buf_len)); + buf[offset] = STATUS_BYTE; + offset += 1; + + bail!(self(offset, buf_len)); + buf[offset] = self.status.into(); + offset += 1; + + bail!(self(offset, buf_len)); + buf[offset] = MEMBERLIST_PROTOCOL_VERSION_BYTE; + offset += 1; + + bail!(self(offset, buf_len)); + buf[offset] = self.memberlist_protocol_version.into(); + offset += 1; + + bail!(self(offset, buf_len)); + buf[offset] = MEMBERLIST_DELEGATE_VERSION_BYTE; + offset += 1; + + bail!(self(offset, buf_len)); + buf[offset] = self.memberlist_delegate_version.into(); + offset += 1; + + bail!(self(offset, buf_len)); + buf[offset] = PROTOCOL_VERSION_BYTE; + offset += 1; + + bail!(self(offset, buf_len)); + buf[offset] = self.protocol_version.into(); + offset += 1; + + bail!(self(offset, buf_len)); + buf[offset] = DELEGATE_VERSION_BYTE; + offset += 1; + + #[cfg(debug_assertions)] + super::debug_assert_write_eq(offset, self.encoded_len()); + + Ok(offset) + } +} diff --git a/serf-proto/src/query.rs b/serf-proto/src/query.rs index 55a1067..c4a2d6d 100644 --- a/serf-proto/src/query.rs +++ b/serf-proto/src/query.rs @@ -2,10 +2,18 @@ use smol_str::SmolStr; use std::time::Duration; -use memberlist_proto::{Node, TinyVec, bytes::Bytes}; +use memberlist_proto::{ + Data, DataRef, DecodeError, EncodeError, Node, RepeatedDecoder, TinyVec, WireType, + bytes::Bytes, + utils::{merge, skip, split}, +}; use super::{Filter, LamportTime}; +pub use response::*; + +mod response; + bitflags::bitflags! 
{ /// Flags for query message #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] @@ -20,6 +28,26 @@ bitflags::bitflags! { } } +const LTIME_TAG: u8 = 1; +const ID_TAG: u8 = 2; +const FROM_TAG: u8 = 3; +const FILTERS_TAG: u8 = 4; +const FLAGS_TAG: u8 = 5; +const RELAY_FACTOR_TAG: u8 = 6; +const TIMEOUT_TAG: u8 = 7; +const NAME_TAG: u8 = 8; +const PAYLOAD_TAG: u8 = 9; + +const LTIME_BYTE: u8 = merge(WireType::Varint, LTIME_TAG); +const ID_BYTE: u8 = merge(WireType::Varint, ID_TAG); +const FROM_BYTE: u8 = merge(WireType::LengthDelimited, FROM_TAG); +const FILTERS_BYTE: u8 = merge(WireType::LengthDelimited, FILTERS_TAG); +const FLAGS_BYTE: u8 = merge(WireType::Varint, FLAGS_TAG); +const RELAY_FACTOR_BYTE: u8 = merge(WireType::Varint, RELAY_FACTOR_TAG); +const TIMEOUT_BYTE: u8 = merge(WireType::Varint, TIMEOUT_TAG); +const NAME_BYTE: u8 = merge(WireType::LengthDelimited, NAME_TAG); +const PAYLOAD_BYTE: u8 = merge(WireType::LengthDelimited, PAYLOAD_TAG); + /// Query message #[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))] #[derive(Debug, Clone, Eq, PartialEq)] @@ -106,58 +134,392 @@ impl QueryMessage { } } -/// Query response message -#[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))] -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] -pub struct QueryResponseMessage { +/// The reference type of [`QueryMessage`] +#[viewit::viewit(vis_all = "", getters(vis_all = "pub", style = "ref"), setters(skip))] +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub struct QueryMessageRef<'a, I, A> { /// Event lamport time - #[viewit( - getter(const, attrs(doc = "Returns the lamport time for this message")), - setter( - const, - attrs(doc = "Sets the lamport time for this message (Builder pattern)") - ) - )] + #[viewit(getter(const, style = "move", attrs(doc = "Returns the event lamport time")))] ltime: LamportTime, - /// 
query id - #[viewit( - getter(const, attrs(doc = "Returns the query id")), - setter(attrs(doc = "Sets the query id (Builder pattern)")) - )] + /// query id, randomly generated + #[viewit(getter(const, style = "move", attrs(doc = "Returns the query id")))] id: u32, - /// node - #[viewit( - getter(const, attrs(doc = "Returns the from node")), - setter(attrs(doc = "Sets the from node (Builder pattern)")) - )] + /// source node + #[viewit(getter(const, attrs(doc = "Returns the from node")))] from: Node, + /// Potential query filters + #[viewit(getter(const, attrs(doc = "Returns the potential query filters")))] + filters: RepeatedDecoder<'a>, /// Used to provide various flags - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns the flags")), - setter(attrs(doc = "Sets the flags (Builder pattern)")) - )] + #[viewit(getter(const, style = "move", attrs(doc = "Returns the flags")))] flags: QueryFlag, - /// Optional response payload - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns the payload")), - setter(attrs(doc = "Sets the payload (Builder pattern)")) - )] - #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, Bytes>))] - payload: Bytes, + /// Used to set the number of duplicate relayed responses + #[viewit(getter( + const, + style = "move", + attrs(doc = "Returns the number of duplicate relayed responses") + ))] + relay_factor: u8, + /// Maximum time between delivery and response + #[viewit(getter( + const, + style = "move", + attrs(doc = "Returns the maximum time between delivery and response") + ))] + timeout: Duration, + /// Query nqme + #[viewit(getter(const, style = "ref", attrs(doc = "Returns the name of the query")))] + name: &'a str, + /// Query payload + #[viewit(getter(const, style = "ref", attrs(doc = "Returns the payload")))] + payload: &'a [u8], } -impl QueryResponseMessage { - /// Checks if the ack flag is set - #[inline] - pub fn ack(&self) -> bool { - self.flags.contains(QueryFlag::ACK) +impl<'a, I, 
A> DataRef<'a, QueryMessage> for QueryMessageRef<'a, I::Ref<'a>, A::Ref<'a>> +where + I: Data, + A: Data, +{ + fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> + where + Self: Sized, + { + let mut offset = 0; + let buf_len = buf.len(); + + let mut ltime = None; + let mut id = None; + let mut from = None; + let mut filters_offsets = None; + let mut num_filters = 0; + let mut flags = None; + let mut relay_factor = None; + let mut timeout = None; + let mut name = None; + let mut payload = None; + + while offset < buf_len { + match buf[offset] { + LTIME_BYTE => { + if ltime.is_some() { + return Err(DecodeError::duplicate_field( + "QueryMessage", + "ltime", + LTIME_TAG, + )); + } + + offset += 1; + let (o, v) = >::decode(&buf[offset..])?; + offset += o; + ltime = Some(v); + } + ID_BYTE => { + if id.is_some() { + return Err(DecodeError::duplicate_field("QueryMessage", "id", ID_TAG)); + } + + offset += 1; + let (o, v) = >::decode(&buf[offset..])?; + offset += o; + id = Some(v); + } + FROM_BYTE => { + if from.is_some() { + return Err(DecodeError::duplicate_field( + "QueryMessage", + "from", + FROM_TAG, + )); + } + + offset += 1; + let (o, v) = + , A::Ref<'_>> as DataRef<'_, Node>>::decode(&buf[offset..])?; + offset += o; + from = Some(v); + } + FILTERS_BYTE => { + let readed = skip(WireType::LengthDelimited, &buf[offset..])?; + if let Some((ref mut fnso, ref mut lnso)) = filters_offsets { + if *fnso > offset { + *fnso = offset - 1; + } + + if *lnso < offset + readed { + *lnso = offset + readed; + } + } else { + filters_offsets = Some((offset - 1, offset + readed)); + } + num_filters += 1; + offset += readed; + } + FLAGS_BYTE => { + if flags.is_some() { + return Err(DecodeError::duplicate_field( + "QueryMessage", + "flags", + FLAGS_TAG, + )); + } + + offset += 1; + let (o, v) = >::decode(&buf[offset..])?; + offset += o; + flags = Some(QueryFlag::from_bits_truncate(v)); + } + RELAY_FACTOR_BYTE => { + if relay_factor.is_some() { + return 
Err(DecodeError::duplicate_field( + "QueryMessage", + "relay_factor", + RELAY_FACTOR_TAG, + )); + } + + offset += 1; + let (o, v) = >::decode(&buf[offset..])?; + offset += o; + relay_factor = Some(v); + } + TIMEOUT_BYTE => { + if timeout.is_some() { + return Err(DecodeError::duplicate_field( + "QueryMessage", + "timeout", + TIMEOUT_TAG, + )); + } + + offset += 1; + let (o, v) = >::decode(&buf[offset..])?; + offset += o; + timeout = Some(v); + } + NAME_BYTE => { + if name.is_some() { + return Err(DecodeError::duplicate_field( + "QueryMessage", + "name", + NAME_TAG, + )); + } + + offset += 1; + let (o, v) = <&str as DataRef<'_, SmolStr>>::decode_length_delimited(&buf[offset..])?; + offset += o; + name = Some(v); + } + PAYLOAD_BYTE => { + if payload.is_some() { + return Err(DecodeError::duplicate_field( + "QueryMessage", + "payload", + PAYLOAD_TAG, + )); + } + + offset += 1; + let (o, v) = <&[u8] as DataRef<'_, Bytes>>::decode_length_delimited(&buf[offset..])?; + offset += o; + payload = Some(v); + } + other => { + offset += 1; + + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &buf[offset..])?; + } + } + } + + let filters = + RepeatedDecoder::new(FILTERS_TAG, WireType::LengthDelimited, buf).with_nums(num_filters); + + Ok(( + offset, + Self { + ltime: ltime.ok_or_else(|| DecodeError::missing_field("QueryMessage", "ltime"))?, + id: id.ok_or_else(|| DecodeError::missing_field("QueryMessage", "id"))?, + from: from.ok_or_else(|| DecodeError::missing_field("QueryMessage", "from"))?, + filters: if let Some((start, end)) = filters_offsets { + filters.with_offsets(start, end) + } else { + filters + }, + flags: flags.ok_or_else(|| DecodeError::missing_field("QueryMessage", "flags"))?, + relay_factor: relay_factor + .ok_or_else(|| DecodeError::missing_field("QueryMessage", "relay_factor"))?, + timeout: timeout.ok_or_else(|| DecodeError::missing_field("QueryMessage", 
"timeout"))?, + name: name.unwrap_or_default(), + payload: payload.unwrap_or_default(), + }, + )) } +} - /// Checks if the no broadcast flag is set - #[inline] - pub fn no_broadcast(&self) -> bool { - self.flags.contains(QueryFlag::NO_BROADCAST) +impl Data for QueryMessage +where + I: Data, + A: Data, +{ + type Ref<'a> = QueryMessageRef<'a, I::Ref<'a>, A::Ref<'a>>; + + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + val + .filters + .iter::>() + .map(|res| res.and_then(Data::from_ref)) + .collect::, DecodeError>>() + .and_then(|filters| { + Ok(Self { + ltime: val.ltime, + id: val.id, + from: Node::from_ref(val.from)?, + filters, + flags: val.flags, + relay_factor: val.relay_factor, + timeout: val.timeout, + name: SmolStr::from(val.name), + payload: Bytes::copy_from_slice(val.payload), + }) + }) + } + + fn encoded_len(&self) -> usize { + let mut len = 0; + + len += 1 + self.ltime.encoded_len(); + len += 1 + self.id.encoded_len(); + len += 1 + self.from.encoded_len_with_length_delimited(); + len += self + .filters + .iter() + .map(|f| 1 + f.encoded_len_with_length_delimited()) + .sum::(); + + len += 1 + self.flags.bits().encoded_len(); + len += 1 + self.relay_factor.encoded_len(); + len += 1 + self.timeout.encoded_len(); + + let nlen = self.name.len(); + + if nlen != 0 { + len += 1 + self.name.encoded_len_with_length_delimited(); + } + + let plen = self.payload.len(); + + if plen != 0 { + len += 1 + self.payload.encoded_len_with_length_delimited(); + } + + len + } + + fn encode(&self, buf: &mut [u8]) -> Result { + macro_rules! 
bail { + ($this:ident($offset:expr, $len:ident)) => { + if $offset >= $len { + return Err(EncodeError::insufficient_buffer(self.encoded_len(), $len)); + } + }; + } + + let mut offset = 0; + let buf_len = buf.len(); + + bail!(self(offset, buf_len)); + buf[offset] = LTIME_BYTE; + offset += 1; + + offset += self + .ltime + .encode(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + + bail!(self(offset, buf_len)); + buf[offset] = ID_BYTE; + offset += 1; + + offset += self + .id + .encode(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + + bail!(self(offset, buf_len)); + buf[offset] = FROM_BYTE; + offset += 1; + + offset += self + .from + .encode_length_delimited(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + + for filter in self.filters.iter() { + bail!(self(offset, buf_len)); + buf[offset] = FILTERS_BYTE; + offset += 1; + + offset += filter + .encode_length_delimited(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + } + + bail!(self(offset, buf_len)); + buf[offset] = FLAGS_BYTE; + offset += 1; + + offset += ::encode(&self.flags.bits(), &mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + + bail!(self(offset, buf_len)); + buf[offset] = RELAY_FACTOR_BYTE; + offset += 1; + + bail!(self(offset, buf_len)); + buf[offset] = self.relay_factor; + offset += 1; + + bail!(self(offset, buf_len)); + buf[offset] = TIMEOUT_BYTE; + offset += 1; + + offset += self + .timeout + .encode(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + + if !self.name.is_empty() { + bail!(self(offset, buf_len)); + buf[offset] = NAME_BYTE; + offset += 1; + + offset += self + .name + .encode_length_delimited(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + } + + if !self.payload.is_empty() { + bail!(self(offset, buf_len)); + buf[offset] = PAYLOAD_BYTE; + offset += 1; + + offset += self + .payload + 
.encode_length_delimited(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + } + + #[cfg(debug_assertions)] + super::debug_assert_write_eq(offset, self.encoded_len()); + + Ok(offset) } } diff --git a/serf-proto/src/query/response.rs b/serf-proto/src/query/response.rs new file mode 100644 index 0000000..e7e5c1d --- /dev/null +++ b/serf-proto/src/query/response.rs @@ -0,0 +1,310 @@ +use memberlist_proto::{ + Data, DataRef, DecodeError, EncodeError, Node, WireType, + bytes::Bytes, + utils::{merge, skip, split}, +}; + +use crate::LamportTime; + +use super::QueryFlag; + +const LTIME_TAG: u8 = 1; +const ID_TAG: u8 = 2; +const FROM_TAG: u8 = 3; +const FLAGS_TAG: u8 = 4; +const PAYLOAD_TAG: u8 = 5; + +const LTIME_BYTE: u8 = merge(WireType::Varint, LTIME_TAG); +const ID_BYTE: u8 = merge(WireType::Varint, ID_TAG); +const FROM_BYTE: u8 = merge(WireType::LengthDelimited, FROM_TAG); +const FLAGS_BYTE: u8 = merge(WireType::Varint, FLAGS_TAG); +const PAYLOAD_BYTE: u8 = merge(WireType::LengthDelimited, PAYLOAD_TAG); + +/// Query response message +#[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))] +#[derive(Debug, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct QueryResponseMessage { + /// Event lamport time + #[viewit( + getter(const, attrs(doc = "Returns the lamport time for this message")), + setter( + const, + attrs(doc = "Sets the lamport time for this message (Builder pattern)") + ) + )] + ltime: LamportTime, + /// query id + #[viewit( + getter(const, attrs(doc = "Returns the query id")), + setter(attrs(doc = "Sets the query id (Builder pattern)")) + )] + id: u32, + /// node + #[viewit( + getter(const, attrs(doc = "Returns the from node")), + setter(attrs(doc = "Sets the from node (Builder pattern)")) + )] + from: Node, + /// Used to provide various flags + #[viewit( + getter(const, style = "ref", 
attrs(doc = "Returns the flags")), + setter(attrs(doc = "Sets the flags (Builder pattern)")) + )] + flags: QueryFlag, + /// Optional response payload + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the payload")), + setter(attrs(doc = "Sets the payload (Builder pattern)")) + )] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, Bytes>))] + payload: Bytes, +} + +impl QueryResponseMessage { + /// Checks if the ack flag is set + #[inline] + pub fn ack(&self) -> bool { + self.flags.contains(QueryFlag::ACK) + } + + /// Checks if the no broadcast flag is set + #[inline] + pub fn no_broadcast(&self) -> bool { + self.flags.contains(QueryFlag::NO_BROADCAST) + } +} + +/// The reference type to a query response message +#[viewit::viewit(vis_all = "", getters(vis_all = "pub", style = "ref"), setters(skip))] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct QueryResponseMessageRef<'a, I, A> { + /// Event lamport time + #[viewit(getter(const, attrs(doc = "Returns the lamport time for this message")))] + ltime: LamportTime, + /// query id + #[viewit(getter(const, attrs(doc = "Returns the query id")))] + id: u32, + /// node + #[viewit(getter(const, attrs(doc = "Returns the from node")))] + from: Node, + /// Used to provide various flags + #[viewit(getter(const, style = "ref", attrs(doc = "Returns the flags")))] + flags: QueryFlag, + /// Optional response payload + #[viewit(getter(const, style = "ref", attrs(doc = "Returns the payload")))] + payload: &'a [u8], +} + +impl<'a, I, A> DataRef<'a, QueryResponseMessage> + for QueryResponseMessageRef<'a, I::Ref<'a>, A::Ref<'a>> +where + I: Data, + A: Data, +{ + fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> + where + Self: Sized, + { + let mut offset = 0; + let buf_len = buf.len(); + + let mut ltime = None; + let mut id = None; + let mut from = None; + let mut flags = None; + let mut payload = None; + + while offset < buf_len { + match buf[offset] { + LTIME_BYTE => { 
+ if ltime.is_some() { + return Err(DecodeError::duplicate_field( + "QueryResponseMessage", + "ltime", + LTIME_TAG, + )); + } + + offset += 1; + let (o, v) = >::decode(&buf[offset..])?; + offset += o; + ltime = Some(v); + } + ID_BYTE => { + if id.is_some() { + return Err(DecodeError::duplicate_field( + "QueryResponseMessage", + "id", + ID_TAG, + )); + } + + offset += 1; + let (o, v) = >::decode(&buf[offset..])?; + offset += o; + id = Some(v); + } + FROM_BYTE => { + if from.is_some() { + return Err(DecodeError::duplicate_field( + "QueryResponseMessage", + "from", + FROM_TAG, + )); + } + + offset += 1; + let (o, v) = + , A::Ref<'_>> as DataRef<'_, Node>>::decode(&buf[offset..])?; + offset += o; + from = Some(v); + } + FLAGS_BYTE => { + if flags.is_some() { + return Err(DecodeError::duplicate_field( + "QueryResponseMessage", + "flags", + FLAGS_TAG, + )); + } + + offset += 1; + let (o, v) = >::decode(&buf[offset..])?; + offset += o; + flags = Some(QueryFlag::from_bits_retain(v)); + } + PAYLOAD_BYTE => { + if payload.is_some() { + return Err(DecodeError::duplicate_field( + "QueryResponseMessage", + "payload", + PAYLOAD_TAG, + )); + } + + offset += 1; + let (o, v) = <&[u8] as DataRef<'_, Bytes>>::decode(&buf[offset..])?; + offset += o; + payload = Some(v); + } + other => { + offset += 1; + + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &buf[offset..])?; + } + } + } + + Ok(( + offset, + Self { + ltime: ltime.ok_or_else(|| DecodeError::missing_field("QueryResponseMessage", "ltime"))?, + id: id.ok_or_else(|| DecodeError::missing_field("QueryResponseMessage", "id"))?, + from: from.ok_or_else(|| DecodeError::missing_field("QueryResponseMessage", "from"))?, + flags: flags.ok_or_else(|| DecodeError::missing_field("QueryResponseMessage", "flags"))?, + payload: payload.unwrap_or_default(), + }, + )) + } +} + +impl Data for QueryResponseMessage +where + I: Data, + A: Data, 
+{ + type Ref<'a> = QueryResponseMessageRef<'a, I::Ref<'a>, A::Ref<'a>>; + + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + Ok(Self { + ltime: val.ltime, + id: val.id, + from: Node::from_ref(val.from)?, + flags: val.flags, + payload: Bytes::copy_from_slice(val.payload), + }) + } + + fn encoded_len(&self) -> usize { + let mut len = 1 + self.ltime.encoded_len(); + len += 1 + self.id.encoded_len(); + len += 1 + self.from.encoded_len_with_length_delimited(); + len += 1 + self.flags.bits().encoded_len(); // flags + let plen = self.payload.len(); + if plen > 0 { + len += 1 + self.payload.encoded_len_with_length_delimited(); + } + + len + } + + fn encode(&self, buf: &mut [u8]) -> Result { + macro_rules! bail { + ($this:ident($offset:expr, $len:ident)) => { + if $offset >= $len { + return Err(EncodeError::insufficient_buffer(self.encoded_len(), $len)); + } + }; + } + + let buf_len = buf.len(); + let mut offset = 0; + + bail!(self(offset, buf_len)); + buf[offset] = LTIME_BYTE; + offset += 1; + + offset += self + .ltime + .encode(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + + bail!(self(offset, buf_len)); + buf[offset] = ID_BYTE; + offset += 1; + + offset += self + .id + .encode(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + + bail!(self(offset, buf_len)); + buf[offset] = FROM_BYTE; + offset += 1; + + offset += self + .from + .encode_length_delimited(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + + bail!(self(offset, buf_len)); + buf[offset] = FLAGS_BYTE; + offset += 1; + + offset += ::encode(&self.flags.bits(), &mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + + if !self.payload.is_empty() { + bail!(self(offset, buf_len)); + buf[offset] = PAYLOAD_BYTE; + offset += 1; + + offset += self + .payload + .encode_length_delimited(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + } + + #[cfg(debug_assertions)] 
+ super::super::debug_assert_write_eq(offset, self.encoded_len()); + + Ok(offset) + } +} diff --git a/serf-proto/src/tags.rs b/serf-proto/src/tags.rs index dd18344..5835e17 100644 --- a/serf-proto/src/tags.rs +++ b/serf-proto/src/tags.rs @@ -1,6 +1,13 @@ use indexmap::IndexMap; +use memberlist_proto::{ + Data, DataRef, DecodeError, EncodeError, RepeatedDecoder, WireType, + utils::{merge, skip, split}, +}; use smol_str::SmolStr; +const TAGS_TAG: u8 = 1; +const TAGS_BYTE: u8 = merge(WireType::LengthDelimited, TAGS_TAG); + /// Tags of a node #[derive( Debug, @@ -58,3 +65,316 @@ impl Tags { Self(IndexMap::with_capacity(cap)) } } + +/// The reference type to [`Tags`], which is an iterator and yields a reference to the key and value +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct TagsRef<'a> { + src: RepeatedDecoder<'a>, +} + +impl<'a> DataRef<'a, Tags> for TagsRef<'a> { + fn decode(src: &'a [u8]) -> Result<(usize, Self), memberlist_proto::DecodeError> + where + Self: Sized, + { + let mut offset = 0; + let buf_len = src.len(); + + let mut tags_offsets = None; + let mut num_tags = 0; + + while offset < buf_len { + match src[offset] { + TAGS_BYTE => { + let readed = skip(WireType::LengthDelimited, &src[offset..])?; + if let Some((ref mut fnso, ref mut lnso)) = tags_offsets { + if *fnso > offset { + *fnso = offset - 1; + } + + if *lnso < offset + readed { + *lnso = offset + readed; + } + } else { + tags_offsets = Some((offset - 1, offset + readed)); + } + num_tags += 1; + offset += readed; + } + other => { + offset += 1; + + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &src[offset..])?; + } + } + } + + let decoder = + RepeatedDecoder::new(TAGS_TAG, WireType::LengthDelimited, src).with_nums(num_tags); + + Ok(( + offset, + Self { + src: if let Some((fnso, lnso)) = tags_offsets { + decoder.with_offsets(fnso, lnso) + } else { + decoder + }, + }, + )) + } +} + 
+impl Data for Tags { + type Ref<'a> = TagsRef<'a>; + + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + val + .src + .iter::() + .map(|res| res.and_then(|t| Tag::from_ref(t).map(Tag::split))) + .collect::, DecodeError>>() + .map(Self) + } + + fn encoded_len(&self) -> usize { + self + .0 + .iter() + .map(|(k, v)| 1 + TagRef::new(k, v).encoded_len_with_length_delimited()) + .sum::() + } + + fn encode(&self, buf: &mut [u8]) -> Result { + let buf_len = buf.len(); + self + .0 + .iter() + .try_fold(0, |mut offset, (k, v)| { + if offset >= buf_len { + return Err(EncodeError::insufficient_buffer(1, 0)); + } + + buf[offset] = TAGS_BYTE; + offset += 1; + offset += TagRef::new(k, v).encode_with_length_delimited(&mut buf[offset..])?; + Ok(offset) + }) + .map_err(|e: EncodeError| e.update(self.encoded_len(), buf.len())) + } +} + +#[derive(Debug)] +struct Tag { + key: SmolStr, + value: SmolStr, +} + +impl Tag { + fn split(self) -> (SmolStr, SmolStr) { + (self.key, self.value) + } +} + +impl Data for Tag { + type Ref<'a> = TagRef<'a>; + + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + Ok(Self { + key: SmolStr::new(val.key), + value: SmolStr::new(val.value), + }) + } + + fn encoded_len(&self) -> usize { + TagRef::new(&self.key, &self.value).encoded_len() + } + + fn encode(&self, buf: &mut [u8]) -> Result { + TagRef::new(&self.key, &self.value).encode(buf) + } +} + +/// A reference to a (key, value) pair of a tag +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct TagRef<'a> { + key: &'a str, + value: &'a str, +} + +impl<'a> DataRef<'a, Tag> for TagRef<'a> { + fn decode(src: &'a [u8]) -> Result<(usize, Self), DecodeError> + where + Self: Sized, + { + let mut offset = 0; + let buf_len = src.len(); + + let mut key = None; + let mut val = None; + + while offset < buf_len { + match src[offset] { + Self::KEY_BYTE => { + if key.is_some() { + return Err(DecodeError::duplicate_field("Tag", "key", Self::KEY_TAG)); + } + offset += 1; 
+ + let (read, value) = + <&str as DataRef<'_, SmolStr>>::decode_length_delimited(&src[offset..])?; + key = Some(value); + offset += read; + } + Self::VALUE_BYTE => { + if val.is_some() { + return Err(DecodeError::duplicate_field( + "Tag", + "value", + Self::VALUE_TAG, + )); + } + offset += 1; + + let (read, value) = + <&str as DataRef<'_, SmolStr>>::decode_length_delimited(&src[offset..])?; + val = Some(value); + offset += read; + } + other => { + offset += 1; + + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &src[offset..])?; + } + } + } + + Ok(( + offset, + Self { + key: key.unwrap_or(""), + value: val.unwrap_or(""), + }, + )) + } +} + +impl<'a> TagRef<'a> { + const KEY_TAG: u8 = 1; + const KEY_BYTE: u8 = merge(WireType::LengthDelimited, Self::KEY_TAG); + const VALUE_TAG: u8 = 2; + const VALUE_BYTE: u8 = merge(WireType::LengthDelimited, Self::VALUE_TAG); + + fn new(key: &'a str, value: &'a str) -> Self { + Self { key, value } + } + + fn encoded_len(&self) -> usize { + let klen = self.key.len(); + let vlen = self.value.len(); + + let mut len = 0; + if klen != 0 { + len += 1 + (klen as u32).encoded_len(); + } + + if vlen != 0 { + len += 1 + (vlen as u32).encoded_len(); + } + + len + } + + fn encoded_len_with_length_delimited(&self) -> usize { + let len = self.encoded_len(); + len + (len as u32).encoded_len() + } + + fn encode(&self, buf: &mut [u8]) -> Result { + let buf_len = buf.len(); + let mut offset = 0; + + if buf_len <= offset { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + + let klen = self.key.len(); + if klen != 0 { + buf[offset] = Self::KEY_BYTE; + offset += 1; + + offset += (klen as u32) + .encode(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + if buf_len < offset + klen { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + 
buf[offset..offset + klen].copy_from_slice(self.key.as_bytes()); + offset += klen; + } + + if buf_len <= offset { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + + let vlen = self.value.len(); + if vlen != 0 { + buf[offset] = Self::VALUE_BYTE; + offset += 1; + + offset += (vlen as u32) + .encode(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + if buf_len < offset + vlen { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + + buf[offset..offset + vlen].copy_from_slice(self.value.as_bytes()); + offset += vlen; + } + + #[cfg(debug_assertions)] + super::debug_assert_write_eq(offset, self.encoded_len()); + + Ok(offset) + } + + fn encode_with_length_delimited(&self, buf: &mut [u8]) -> Result { + let len = self.encoded_len(); + let buf_len = buf.len(); + if buf_len < len { + return Err(EncodeError::insufficient_buffer(len, buf_len)); + } + + let mut offset = 0; + offset += (len as u32).encode(&mut buf[offset..])?; + offset += self.encode(&mut buf[offset..])?; + + #[cfg(debug_assertions)] + super::debug_assert_write_eq(offset, self.encoded_len_with_length_delimited()); + + Ok(offset) + } +} diff --git a/serf-proto/src/user_event.rs b/serf-proto/src/user_event.rs index af2d281..a929286 100644 --- a/serf-proto/src/user_event.rs +++ b/serf-proto/src/user_event.rs @@ -1,32 +1,21 @@ -use memberlist_proto::{CheapClone, OneOrMore, bytes::Bytes}; +use memberlist_proto::{ + Data, DataRef, DecodeError, EncodeError, WireType, + bytes::Bytes, + utils::{merge, skip, split}, +}; use smol_str::SmolStr; -use super::LamportTime; +pub use message::*; +pub use user_events::*; -/// Used to buffer events to prevent re-delivery -#[viewit::viewit(setters(prefix = "with"))] -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] -pub struct UserEvents { - /// 
The lamport time - #[viewit( - getter(const, attrs(doc = "Returns the lamport time for this message")), - setter( - const, - attrs(doc = "Sets the lamport time for this message (Builder pattern)") - ) - )] - ltime: LamportTime, +mod message; +mod user_events; - /// The user events - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns the user events")), - setter(attrs(doc = "Sets the user events (Builder pattern)")) - )] - #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, OneOrMore>))] - events: OneOrMore, -} +const NAME_TAG: u8 = 1; +const PAYLOAD_TAG: u8 = 2; + +const NAME_BYTE: u8 = merge(WireType::LengthDelimited, NAME_TAG); +const PAYLOAD_BYTE: u8 = merge(WireType::LengthDelimited, PAYLOAD_TAG); /// Stores all the user events at a specific time #[viewit::viewit(getters(style = "ref"), setters(prefix = "with"))] @@ -49,60 +38,124 @@ pub struct UserEvent { payload: Bytes, } -/// Used for user-generated events -#[viewit::viewit(setters(prefix = "with"))] -#[derive(Debug, Default, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] -pub struct UserEventMessage { - /// The lamport time - #[viewit( - getter( - const, - style = "move", - attrs(doc = "Returns the lamport time for this message") - ), - setter( - const, - attrs(doc = "Sets the lamport time for this message (Builder pattern)") - ) - )] - ltime: LamportTime, - /// The name of the event - #[viewit( - getter(const, attrs(doc = "Returns the name of the event")), - setter(attrs(doc = "Sets the name of the event (Builder pattern)")) - )] - name: SmolStr, - /// The payload of the event - #[viewit( - getter(const, attrs(doc = "Returns the payload of the event")), - setter(attrs(doc = "Sets the payload of the event (Builder pattern)")) - )] - #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, Bytes>))] - payload: Bytes, - /// 
"Can Coalesce". - #[viewit( - getter( - const, - style = "move", - attrs(doc = "Returns if this message can be coalesced") - ), - setter( - const, - attrs(doc = "Sets if this message can be coalesced (Builder pattern)") - ) - )] - cc: bool, +/// The reference to a [`UserEvent`]. +#[viewit::viewit(getters(style = "ref", vis_all = "pub"), setters(skip), vis_all = "")] +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub struct UserEventRef<'a> { + #[viewit(getter(const, attrs(doc = "Returns the name of the event")))] + name: &'a str, + #[viewit(getter(const, attrs(doc = "Returns the payload of the event")))] + payload: &'a [u8], } -impl CheapClone for UserEventMessage { - fn cheap_clone(&self) -> Self { - Self { - ltime: self.ltime, - name: self.name.cheap_clone(), - payload: self.payload.clone(), - cc: self.cc, +impl<'a> DataRef<'a, UserEvent> for UserEventRef<'a> { + fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> + where + Self: Sized, + { + let mut offset = 0; + let buf_len = buf.len(); + + let mut name = None; + let mut payload = None; + + while offset < buf_len { + match buf[offset] { + NAME_BYTE => { + if name.is_some() { + return Err(DecodeError::duplicate_field("UserEvent", "name", NAME_TAG)); + } + offset += 1; + let (size, val) = + <&str as DataRef<'_, SmolStr>>::decode_length_delimited(&buf[offset..])?; + name = Some(val); + offset += size; + } + PAYLOAD_BYTE => { + if payload.is_some() { + return Err(DecodeError::duplicate_field( + "UserEvent", + "payload", + PAYLOAD_TAG, + )); + } + offset += 1; + let (size, val) = <&[u8] as DataRef<'_, Bytes>>::decode_length_delimited(&buf[offset..])?; + payload = Some(val); + offset += size; + } + other => { + offset += 1; + + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &buf[offset..])?; + } + } } + + Ok(( + offset, + Self { + name: name.unwrap_or_default(), + payload: payload.unwrap_or_default(), 
+ }, + )) + } +} + +impl Data for UserEvent { + type Ref<'a> = UserEventRef<'a>; + + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + Ok(Self { + name: SmolStr::from(val.name), + payload: Bytes::copy_from_slice(val.payload), + }) + } + + fn encoded_len(&self) -> usize { + let mut len = 0; + if !self.name.is_empty() { + len += 1 + self.name.encoded_len_with_length_delimited(); + } + + if !self.payload.is_empty() { + len += 1 + self.payload.encoded_len_with_length_delimited(); + } + len + } + + fn encode(&self, buf: &mut [u8]) -> Result { + let mut offset = 0; + let buf_len = buf.len(); + if !self.name.is_empty() { + if offset >= buf_len { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + buf[offset] = NAME_BYTE; + offset += 1; + offset += self.name.encode_length_delimited(&mut buf[offset..])?; + } + + if !self.payload.is_empty() { + if offset >= buf_len { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + buf[offset] = PAYLOAD_BYTE; + offset += 1; + offset += self.payload.encode_length_delimited(&mut buf[offset..])?; + } + + Ok(offset) } } diff --git a/serf-proto/src/user_event/message.rs b/serf-proto/src/user_event/message.rs new file mode 100644 index 0000000..7b192b5 --- /dev/null +++ b/serf-proto/src/user_event/message.rs @@ -0,0 +1,270 @@ +use memberlist_proto::{ + CheapClone, Data, DataRef, DecodeError, EncodeError, WireType, + bytes::Bytes, + utils::{merge, skip, split}, +}; +use smol_str::SmolStr; + +use crate::LamportTime; + +/// Used for user-generated events +#[viewit::viewit(setters(prefix = "with"))] +#[derive(Debug, Default, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct UserEventMessage { + /// The lamport time + #[viewit( + getter( + const, + style = "move", + attrs(doc = "Returns the lamport time for this message") + 
), + setter( + const, + attrs(doc = "Sets the lamport time for this message (Builder pattern)") + ) + )] + ltime: LamportTime, + /// The name of the event + #[viewit( + getter(const, attrs(doc = "Returns the name of the event")), + setter(attrs(doc = "Sets the name of the event (Builder pattern)")) + )] + name: SmolStr, + /// The payload of the event + #[viewit( + getter(const, attrs(doc = "Returns the payload of the event")), + setter(attrs(doc = "Sets the payload of the event (Builder pattern)")) + )] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, Bytes>))] + payload: Bytes, + /// "Can Coalesce". + #[viewit( + getter( + const, + style = "move", + attrs(doc = "Returns if this message can be coalesced") + ), + setter( + const, + attrs(doc = "Sets if this message can be coalesced (Builder pattern)") + ) + )] + cc: bool, +} + +impl CheapClone for UserEventMessage { + fn cheap_clone(&self) -> Self { + Self { + ltime: self.ltime, + name: self.name.cheap_clone(), + payload: self.payload.clone(), + cc: self.cc, + } + } +} + +const LTIME_TAG: u8 = 1; +const CC_TAG: u8 = 2; +const NAME_TAG: u8 = 3; +const PAYLOAD_TAG: u8 = 4; + +const LTIME_BYTE: u8 = merge(WireType::Varint, LTIME_TAG); +const CC_BYTE: u8 = merge(WireType::Byte, CC_TAG); +const NAME_BYTE: u8 = merge(WireType::LengthDelimited, NAME_TAG); +const PAYLOAD_BYTE: u8 = merge(WireType::LengthDelimited, PAYLOAD_TAG); + +/// The reference type of [`UserEventMessage`] +#[viewit::viewit(vis_all = "", getters(vis_all = "pub", style = "ref"), setters(skip))] +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub struct UserEventMessageRef<'a> { + /// The lamport time + #[viewit(getter(const, attrs(doc = "Returns the lamport time for this message")))] + ltime: LamportTime, + /// The name of the event + #[viewit(getter(const, attrs(doc = "Returns the name of the event")))] + name: &'a str, + /// The payload of the event + #[viewit(getter(const, attrs(doc = "Returns the payload of the 
event")))] + payload: &'a [u8], + /// "Can Coalesce". + #[viewit(getter(const, attrs(doc = "Returns if this message can be coalesced")))] + cc: bool, +} + +impl<'a> DataRef<'a, UserEventMessage> for UserEventMessageRef<'a> { + fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> + where + Self: Sized, + { + let mut offset = 0; + let buf_len = buf.len(); + + let mut ltime = None; + let mut name = None; + let mut payload = None; + let mut cc = None; + + while offset < buf_len { + match buf[offset] { + LTIME_BYTE => { + if ltime.is_some() { + return Err(DecodeError::duplicate_field( + "UserEventMessage", + "ltime", + LTIME_TAG, + )); + } + + offset += 1; + let (o, v) = >::decode(&buf[offset..])?; + offset += o; + ltime = Some(v); + } + CC_BYTE => { + if cc.is_some() { + return Err(DecodeError::duplicate_field( + "UserEventMessage", + "cc", + CC_TAG, + )); + } + + offset += 1; + let (o, v) = >::decode(&buf[offset..])?; + offset += o; + cc = Some(v); + } + NAME_BYTE => { + if name.is_some() { + return Err(DecodeError::duplicate_field( + "UserEventMessage", + "name", + NAME_TAG, + )); + } + offset += 1; + let (o, v) = <&str as DataRef<'_, SmolStr>>::decode_length_delimited(&buf[offset..])?; + offset += o; + name = Some(v); + } + PAYLOAD_BYTE => { + if payload.is_some() { + return Err(DecodeError::duplicate_field( + "UserEventMessage", + "payload", + PAYLOAD_TAG, + )); + } + + offset += 1; + let (o, v) = <&[u8] as DataRef<'_, Bytes>>::decode_length_delimited(&buf[offset..])?; + offset += o; + payload = Some(v); + } + other => { + offset += 1; + + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &buf[offset..])?; + } + } + } + + Ok(( + offset, + Self { + ltime: ltime.ok_or_else(|| DecodeError::missing_field("UserEventMessage", "ltime"))?, + name: name.unwrap_or_default(), + payload: payload.unwrap_or_default(), + cc: cc.unwrap_or_default(), + }, + )) + } +} + 
+impl Data for UserEventMessage { + type Ref<'a> = UserEventMessageRef<'a>; + + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + Ok(Self { + ltime: val.ltime, + name: SmolStr::from(val.name), + payload: Bytes::copy_from_slice(val.payload), + cc: val.cc, + }) + } + + fn encoded_len(&self) -> usize { + let mut len = 1 + self.ltime.encoded_len(); + let nlen = self.name.len(); + if nlen > 0 { + len += 1 + self.name.encoded_len_with_length_delimited(); + } + + let plen = self.payload.len(); + if plen > 0 { + len += 1 + self.payload.encoded_len_with_length_delimited(); + } + + if self.cc { + len += 1 + 1; + } + + len + } + + fn encode(&self, buf: &mut [u8]) -> Result { + macro_rules! bail { + ($this:ident($offset:expr, $len:ident)) => { + if $offset >= $len { + return Err(EncodeError::insufficient_buffer(self.encoded_len(), $len)); + } + }; + } + + let buf_len = buf.len(); + let mut offset = 0; + + bail!(self(offset, buf_len)); + buf[offset] = LTIME_BYTE; + offset += 1; + offset += self + .ltime + .encode(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + + if self.cc { + bail!(self(offset, buf_len)); + buf[offset] = CC_BYTE; + offset += 1; + bail!(self(offset, buf_len)); + buf[offset] = 1; + offset += 1; + } + + if !self.name.is_empty() { + bail!(self(offset, buf_len)); + buf[offset] = NAME_BYTE; + offset += 1; + offset += self.name.encode_length_delimited(&mut buf[offset..])?; + } + + if !self.payload.is_empty() { + bail!(self(offset, buf_len)); + buf[offset] = PAYLOAD_BYTE; + offset += 1; + offset += self.payload.encode_length_delimited(&mut buf[offset..])?; + } + + #[cfg(debug_assertions)] + super::super::debug_assert_write_eq(offset, self.encoded_len()); + + Ok(offset) + } +} diff --git a/serf-proto/src/user_event/user_events.rs b/serf-proto/src/user_event/user_events.rs new file mode 100644 index 0000000..17f02d1 --- /dev/null +++ b/serf-proto/src/user_event/user_events.rs @@ -0,0 +1,182 @@ +use memberlist_proto::{ + 
Data, DataRef, DecodeError, EncodeError, OneOrMore, RepeatedDecoder, WireType, + utils::{merge, skip, split}, +}; + +use crate::LamportTime; + +use super::UserEvent; + +const LTIME_TAG: u8 = 1; +const EVENTS_TAG: u8 = 2; + +const LTIME_BYTE: u8 = merge(WireType::Varint, LTIME_TAG); +const EVENTS_BYTE: u8 = merge(WireType::LengthDelimited, EVENTS_TAG); + +/// Used to buffer events to prevent re-delivery +#[viewit::viewit(setters(prefix = "with"))] +#[derive(Debug, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct UserEvents { + /// The lamport time + #[viewit( + getter(const, attrs(doc = "Returns the lamport time for this message")), + setter( + const, + attrs(doc = "Sets the lamport time for this message (Builder pattern)") + ) + )] + ltime: LamportTime, + + /// The user events + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the user events")), + setter(attrs(doc = "Sets the user events (Builder pattern)")) + )] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, OneOrMore>))] + events: OneOrMore, +} + +/// The reference type for [`UserEvents`] +#[viewit::viewit(vis_all = "", getters(vis_all = "pub", style = "ref"), setters(skip))] +#[derive(Debug, Clone, Copy)] +pub struct UserEventsRef<'a> { + /// The lamport time + #[viewit(getter(const, attrs(doc = "Returns the lamport time for this message")))] + ltime: LamportTime, + + /// The user events + #[viewit(getter(const, style = "ref", attrs(doc = "Returns the bu user events")))] + events: RepeatedDecoder<'a>, +} + +impl<'a> DataRef<'a, UserEvents> for UserEventsRef<'a> { + fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> + where + Self: Sized, + { + let mut offset = 0; + let buf_len = buf.len(); + + let mut ltime = None; + let mut events_offsets = None; + let mut num_events = 0; + + while offset < buf_len { + match buf[offset] { 
+ LTIME_BYTE => { + if ltime.is_some() { + return Err(DecodeError::duplicate_field( + "UserEvents", + "ltime", + LTIME_TAG, + )); + } + offset += 1; + let (size, val) = >::decode(&buf[offset..])?; + ltime = Some(val); + offset += size; + } + EVENTS_BYTE => { + let readed = super::skip(WireType::LengthDelimited, &buf[offset..])?; + if let Some((ref mut fnso, ref mut lnso)) = events_offsets { + if *fnso > offset { + *fnso = offset - 1; + } + + if *lnso < offset + readed { + *lnso = offset + readed; + } + } else { + events_offsets = Some((offset - 1, offset + readed)); + } + num_events += 1; + offset += readed; + } + other => { + offset += 1; + + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &buf[offset..])?; + } + } + } + + Ok(( + offset, + Self { + ltime: ltime.ok_or_else(|| DecodeError::missing_field("UserEvents", "ltime"))?, + events: if let Some((start, end)) = events_offsets { + RepeatedDecoder::new(EVENTS_TAG, WireType::LengthDelimited, buf) + .with_nums(num_events) + .with_offsets(start, end) + } else { + RepeatedDecoder::new(EVENTS_TAG, WireType::LengthDelimited, buf) + }, + }, + )) + } +} + +impl Data for UserEvents { + type Ref<'a> = UserEventsRef<'a>; + + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + val + .events + .iter::() + .map(|ev| ev.and_then(UserEvent::from_ref)) + .collect::, DecodeError>>() + .map(|events| Self { + ltime: val.ltime, + events, + }) + } + + fn encoded_len(&self) -> usize { + let mut len = 0; + len += 1 + self.ltime.encoded_len(); + len += self + .events + .iter() + .map(|e| 1 + e.encoded_len_with_length_delimited()) + .sum::(); + len + } + + fn encode(&self, buf: &mut [u8]) -> Result { + let buf_len = buf.len(); + let mut offset = 0; + + if offset >= buf_len { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + + buf[offset] = LTIME_BYTE; + offset += 1; + + 
offset += self + .ltime + .encode(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + + self + .events + .iter() + .try_fold(&mut offset, |offset, ev| { + *offset += ev.encode_length_delimited(&mut buf[*offset..])?; + + Ok(offset) + }) + .map(|offset| *offset) + .map_err(|e: EncodeError| e.update(self.encoded_len(), buf_len)) + } +} diff --git a/serf-proto/src/version.rs b/serf-proto/src/version.rs index dbbff00..bb110a6 100644 --- a/serf-proto/src/version.rs +++ b/serf-proto/src/version.rs @@ -5,7 +5,6 @@ use memberlist_proto::{Data, DataRef, DecodeError, EncodeError, WireType}; Debug, Default, Copy, Clone, PartialEq, Eq, Hash, derive_more::IsVariant, derive_more::Display, )] #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] -#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[non_exhaustive] pub enum DelegateVersion { /// Version 1 @@ -40,7 +39,6 @@ impl From for u8 { Debug, Default, Copy, Clone, PartialEq, Eq, Hash, derive_more::IsVariant, derive_more::Display, )] #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))] -#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[non_exhaustive] pub enum ProtocolVersion { /// Version 1 diff --git a/serf/src/lib.rs b/serf/src/lib.rs index 8a1026e..352ea93 100644 --- a/serf/src/lib.rs +++ b/serf/src/lib.rs @@ -1,32 +1,32 @@ -#![doc = include_str!("../../README.md")] -#![doc(html_logo_url = "https://raw.githubusercontent.com/al8n/serf/main/art/logo_72x72.png")] -#![forbid(unsafe_code)] -#![deny(warnings, missing_docs)] -#![allow(clippy::type_complexity)] -#![cfg_attr(docsrs, feature(doc_cfg))] -#![cfg_attr(docsrs, allow(unused_attributes))] +// #![doc = include_str!("../../README.md")] +// #![doc(html_logo_url = "https://raw.githubusercontent.com/al8n/serf/main/art/logo_72x72.png")] +// #![forbid(unsafe_code)] +// #![deny(warnings, missing_docs)] +// #![allow(clippy::type_complexity)] +// #![cfg_attr(docsrs, 
feature(doc_cfg))] +// #![cfg_attr(docsrs, allow(unused_attributes))] -pub use serf_core::*; +// pub use serf_core::*; -pub use memberlist::{agnostic, transport}; +// pub use memberlist::{agnostic, transport}; -#[cfg(feature = "net")] -pub use memberlist::net; +// #[cfg(feature = "net")] +// pub use memberlist::net; -#[cfg(feature = "quic")] -pub use memberlist::quic; +// #[cfg(feature = "quic")] +// pub use memberlist::quic; -/// [`Serf`](serf_core::Serf) for `tokio` runtime. -#[cfg(feature = "tokio")] -#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))] -pub mod tokio; +// /// [`Serf`](serf_core::Serf) for `tokio` runtime. +// #[cfg(feature = "tokio")] +// #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))] +// pub mod tokio; -/// [`Serf`](serf_core::Serf) for `async-std` runtime. -#[cfg(feature = "async-std")] -#[cfg_attr(docsrs, doc(cfg(feature = "async-std")))] -pub mod async_std; +// /// [`Serf`](serf_core::Serf) for `async-std` runtime. +// #[cfg(feature = "async-std")] +// #[cfg_attr(docsrs, doc(cfg(feature = "async-std")))] +// pub mod async_std; -/// [`Serf`](serf_core::Serf) for `smol` runtime. -#[cfg(feature = "smol")] -#[cfg_attr(docsrs, doc(cfg(feature = "smol")))] -pub mod smol; +// /// [`Serf`](serf_core::Serf) for `smol` runtime. 
+// #[cfg(feature = "smol")] +// #[cfg_attr(docsrs, doc(cfg(feature = "smol")))] +// pub mod smol; From c59c656ba5b3031f25c282be4bd3ab5ff6ae82cf Mon Sep 17 00:00:00 2001 From: al8n Date: Wed, 26 Feb 2025 22:59:43 +0800 Subject: [PATCH 04/39] WIP --- serf-core/src/delegate.rs | 4 - serf-core/src/delegate/composite.rs | 138 +----- serf-core/src/delegate/transform.rs | 304 ------------ serf-core/src/serf/base.rs | 11 + serf-proto/src/key.rs | 305 +++++++++++- serf-proto/src/lib.rs | 4 +- serf-proto/src/message.rs | 740 ++++++++++++++-------------- serf-proto/src/push_pull.rs | 362 ++++++++++++-- serf-proto/src/tags.rs | 428 ++++++++-------- 9 files changed, 1238 insertions(+), 1058 deletions(-) delete mode 100644 serf-core/src/delegate/transform.rs diff --git a/serf-core/src/delegate.rs b/serf-core/src/delegate.rs index 48288fc..563ed2b 100644 --- a/serf-core/src/delegate.rs +++ b/serf-core/src/delegate.rs @@ -6,9 +6,6 @@ pub use merge::*; mod reconnect; pub use reconnect::*; -mod transform; -pub use transform::*; - mod composite; pub use composite::*; @@ -17,7 +14,6 @@ pub use composite::*; /// as they can and generally will be called concurrently. 
pub trait Delegate: MergeDelegate::Id, Address = ::Address> - + TransformDelegate::Id, Address = ::Address> + ReconnectDelegate::Id, Address = ::Address> { /// The id type of the delegate diff --git a/serf-core/src/delegate/composite.rs b/serf-core/src/delegate/composite.rs index e44e81c..39926fe 100644 --- a/serf-core/src/delegate/composite.rs +++ b/serf-core/src/delegate/composite.rs @@ -23,11 +23,9 @@ pub struct CompositeDelegate< A, M = DefaultMergeDelegate, R = NoopReconnectDelegate, - T = LpeTransfromDelegate, > { merge: M, reconnect: R, - transform: T, _m: std::marker::PhantomData<(I, A)>, } @@ -43,13 +41,12 @@ impl CompositeDelegate { Self { merge: Default::default(), reconnect: Default::default(), - transform: Default::default(), _m: std::marker::PhantomData, } } } -impl CompositeDelegate +impl CompositeDelegate where M: MergeDelegate, { @@ -58,43 +55,28 @@ where CompositeDelegate { merge, reconnect: self.reconnect, - transform: self.transform, _m: std::marker::PhantomData, } } } -impl CompositeDelegate { +impl CompositeDelegate { /// Set the [`ReconnectDelegate`] for the `CompositeDelegate`. pub fn with_reconnect_delegate(self, reconnect: NR) -> CompositeDelegate { CompositeDelegate { reconnect, merge: self.merge, - transform: self.transform, _m: std::marker::PhantomData, } } } -impl CompositeDelegate { - /// Set the [`TransformDelegate`] for the `CompositeDelegate`. 
- pub fn with_transform_delegate(self, transform: NT) -> CompositeDelegate { - CompositeDelegate { - transform, - merge: self.merge, - reconnect: self.reconnect, - _m: std::marker::PhantomData, - } - } -} - -impl MergeDelegate for CompositeDelegate +impl MergeDelegate for CompositeDelegate where I: Id, A: CheapClone + Send + Sync + 'static, M: MergeDelegate, R: Send + Sync + 'static, - T: Send + Sync + 'static, { type Error = M::Error; @@ -110,13 +92,12 @@ where } } -impl ReconnectDelegate for CompositeDelegate +impl ReconnectDelegate for CompositeDelegate where I: Id, A: CheapClone + Send + Sync + 'static, M: Send + Sync + 'static, R: ReconnectDelegate, - T: Send + Sync + 'static, { type Id = R::Id; @@ -131,121 +112,12 @@ where } } -impl TransformDelegate for CompositeDelegate -where - I: Id, - A: CheapClone + Send + Sync + 'static, - M: Send + Sync + 'static, - R: Send + Sync + 'static, - T: TransformDelegate, -{ - type Error = T::Error; - - type Id = T::Id; - - type Address = T::Address; - - fn encode_filter( - filter: &Filter, - ) -> Result { - T::encode_filter(filter) - } - - fn decode_filter(bytes: &[u8]) -> Result<(usize, Filter), Self::Error> { - T::decode_filter(bytes) - } - - fn node_encoded_len(node: &Node) -> usize { - T::node_encoded_len(node) - } - - fn encode_node( - node: &Node, - dst: &mut [u8], - ) -> Result { - T::encode_node(node, dst) - } - - fn decode_node( - bytes: impl AsRef<[u8]>, - ) -> Result<(usize, Node), Self::Error> { - T::decode_node(bytes) - } - - fn id_encoded_len(id: &Self::Id) -> usize { - T::id_encoded_len(id) - } - - fn encode_id(id: &Self::Id, dst: &mut [u8]) -> Result { - T::encode_id(id, dst) - } - - fn decode_id(bytes: &[u8]) -> Result<(usize, Self::Id), Self::Error> { - T::decode_id(bytes) - } - - fn address_encoded_len(address: &Self::Address) -> usize { - T::address_encoded_len(address) - } - - fn encode_address(address: &Self::Address, dst: &mut [u8]) -> Result { - T::encode_address(address, dst) - } - - fn 
decode_address(bytes: &[u8]) -> Result<(usize, Self::Address), Self::Error> { - T::decode_address(bytes) - } - - fn coordinate_encoded_len(coordinate: &Coordinate) -> usize { - T::coordinate_encoded_len(coordinate) - } - - fn encode_coordinate(coordinate: &Coordinate, dst: &mut [u8]) -> Result { - T::encode_coordinate(coordinate, dst) - } - - fn decode_coordinate(bytes: &[u8]) -> Result<(usize, Coordinate), Self::Error> { - T::decode_coordinate(bytes) - } - - fn tags_encoded_len(tags: &Tags) -> usize { - T::tags_encoded_len(tags) - } - - fn encode_tags(tags: &Tags, dst: &mut [u8]) -> Result { - T::encode_tags(tags, dst) - } - - fn decode_tags(bytes: &[u8]) -> Result<(usize, Tags), Self::Error> { - T::decode_tags(bytes) - } - - fn message_encoded_len(msg: impl AsMessageRef) -> usize { - T::message_encoded_len(msg) - } - - fn encode_message( - msg: impl AsMessageRef, - dst: impl AsMut<[u8]>, - ) -> Result { - T::encode_message(msg, dst) - } - - fn decode_message( - ty: MessageType, - bytes: impl AsRef<[u8]>, - ) -> Result<(usize, SerfMessage), Self::Error> { - T::decode_message(ty, bytes) - } -} - -impl Delegate for CompositeDelegate +impl Delegate for CompositeDelegate where I: Id, A: CheapClone + Send + Sync + 'static, M: MergeDelegate, R: ReconnectDelegate, - T: TransformDelegate, { type Id = I; diff --git a/serf-core/src/delegate/transform.rs b/serf-core/src/delegate/transform.rs deleted file mode 100644 index d357466..0000000 --- a/serf-core/src/delegate/transform.rs +++ /dev/null @@ -1,304 +0,0 @@ -use memberlist_core::{ - CheapClone, - bytes::Bytes, - transport::{Id, Node, Transformable}, -}; -use serf_proto::{ - FilterTransformError, JoinMessage, LeaveMessage, Member, MessageType, NodeTransformError, - PushPullMessage, QueryMessage, QueryResponseMessage, SerfMessageTransformError, - TagsTransformError, UserEventMessage, -}; - -use crate::{ - coordinate::{Coordinate, CoordinateTransformError}, - types::{AsMessageRef, Filter, SerfMessage, Tags, 
UnknownMessageType}, -}; - -/// A delegate for encoding and decoding. -pub trait TransformDelegate: Send + Sync + 'static { - /// The error type for the transformation. - type Error: std::error::Error + From + Send + Sync + 'static; - /// The Id type. - type Id: Id; - /// The Address type. - type Address: CheapClone + Send + Sync + 'static; - - /// Encodes the filter into bytes. - fn encode_filter(filter: &Filter) -> Result; - - /// Decodes the filter from the given bytes, returning the number of bytes consumed and the filter. - fn decode_filter(bytes: &[u8]) -> Result<(usize, Filter), Self::Error>; - - /// Returns the encoded length of the node. - fn node_encoded_len(node: &Node) -> usize; - - /// Encodes the node into the given buffer, returning the number of bytes written. - fn encode_node( - node: &Node, - dst: &mut [u8], - ) -> Result; - - /// Decodes [`Node`] from the given bytes, returning the number of bytes consumed and the node. - fn decode_node( - bytes: impl AsRef<[u8]>, - ) -> Result<(usize, Node), Self::Error>; - - /// Returns the encoded length of the id. - fn id_encoded_len(id: &Self::Id) -> usize; - - /// Encodes the id into the given buffer, returning the number of bytes written. - fn encode_id(id: &Self::Id, dst: &mut [u8]) -> Result; - - /// Decodes the id from the given bytes, returning the number of bytes consumed and the id. - fn decode_id(bytes: &[u8]) -> Result<(usize, Self::Id), Self::Error>; - - /// Returns the encoded length of the address. - fn address_encoded_len(address: &Self::Address) -> usize; - - /// Encodes the address into the given buffer, returning the number of bytes written. - fn encode_address(address: &Self::Address, dst: &mut [u8]) -> Result; - - /// Decodes the address from the given bytes, returning the number of bytes consumed and the address. - fn decode_address(bytes: &[u8]) -> Result<(usize, Self::Address), Self::Error>; - - /// Encoded length of the coordinate. 
- fn coordinate_encoded_len(coordinate: &Coordinate) -> usize; - - /// Encodes the coordinate into the given buffer, returning the number of bytes written. - fn encode_coordinate(coordinate: &Coordinate, dst: &mut [u8]) -> Result; - - /// Decodes the coordinate from the given bytes, returning the number of bytes consumed and the coordinate. - fn decode_coordinate(bytes: &[u8]) -> Result<(usize, Coordinate), Self::Error>; - - /// Encoded length of the tags. - fn tags_encoded_len(tags: &Tags) -> usize; - - /// Encodes the tags into the given buffer, returning the number of bytes written. - fn encode_tags(tags: &Tags, dst: &mut [u8]) -> Result; - - /// Decodes the tags from the given bytes, returning the number of bytes consumed and the tags. - fn decode_tags(bytes: &[u8]) -> Result<(usize, Tags), Self::Error>; - - /// Encoded length of the message. - fn message_encoded_len(msg: impl AsMessageRef) -> usize; - - /// Encodes the message into the given buffer, returning the number of bytes written. - /// - /// **NOTE**: - /// - /// 1. The buffer must be large enough to hold the encoded message. - /// The length of the buffer can be obtained by calling [`TransformDelegate::message_encoded_len`]. - /// 2. A message type byte will be automatically prepended to the buffer, - /// so users do not need to encode the message type byte by themselves. - fn encode_message( - msg: impl AsMessageRef, - dst: impl AsMut<[u8]>, - ) -> Result; - - /// Decodes the message from the given bytes, returning the number of bytes consumed and the message. - fn decode_message( - ty: MessageType, - bytes: impl AsRef<[u8]>, - ) -> Result<(usize, SerfMessage), Self::Error>; -} - -/// The error type for the LPE transformation. -#[derive(thiserror::Error)] -pub enum LpeTransformError -where - I: Transformable + core::hash::Hash + Eq, - A: Transformable + core::hash::Hash + Eq, -{ - /// Id transformation error. - #[error(transparent)] - Id(::Error), - /// Address transformation error. 
- #[error(transparent)] - Address(::Error), - /// Coordinate transformation error. - #[error(transparent)] - Coordinate(#[from] CoordinateTransformError), - /// Node transformation error. - #[error(transparent)] - Node(#[from] NodeTransformError), - /// Filter transformation error. - #[error(transparent)] - Filter(#[from] FilterTransformError), - /// Tags transformation error. - #[error(transparent)] - Tags(#[from] TagsTransformError), - /// Serf message transformation error. - #[error(transparent)] - Message(#[from] SerfMessageTransformError), - /// Unknown message type error. - #[error(transparent)] - UnknownMessage(#[from] UnknownMessageType), - /// Unexpected relay message. - #[error("unexpected relay message")] - UnexpectedRelayMessage, -} - -impl core::fmt::Debug for LpeTransformError -where - I: Transformable + core::hash::Hash + Eq, - A: Transformable + core::hash::Hash + Eq, -{ - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "{}", self) - } -} - -/// A length-prefixed encoding [`TransformDelegate`] implementation -pub struct LpeTransfromDelegate(std::marker::PhantomData<(I, A)>); - -impl Default for LpeTransfromDelegate { - fn default() -> Self { - Self(Default::default()) - } -} - -impl Clone for LpeTransfromDelegate { - fn clone(&self) -> Self { - *self - } -} - -impl Copy for LpeTransfromDelegate {} - -impl TransformDelegate for LpeTransfromDelegate -where - I: Id, - A: Transformable + CheapClone + core::hash::Hash + Eq + Send + Sync + 'static, -{ - type Error = LpeTransformError; - type Id = I; - type Address = A; - - fn encode_filter(filter: &Filter) -> Result { - filter - .encode_to_vec() - .map(Bytes::from) - .map_err(Self::Error::Filter) - } - - fn decode_filter(bytes: &[u8]) -> Result<(usize, Filter), Self::Error> { - Filter::decode(bytes).map_err(Self::Error::Filter) - } - - fn node_encoded_len(node: &Node) -> usize { - Transformable::encoded_len(node) - } - - fn encode_node( - node: &Node, - dst: &mut [u8], - 
) -> Result { - Transformable::encode(node, dst).map_err(Self::Error::Node) - } - - fn decode_node( - bytes: impl AsRef<[u8]>, - ) -> Result<(usize, Node), Self::Error> { - Transformable::decode(bytes.as_ref()).map_err(Self::Error::Node) - } - - fn id_encoded_len(id: &Self::Id) -> usize { - Transformable::encoded_len(id) - } - - fn encode_id(id: &Self::Id, dst: &mut [u8]) -> Result { - Transformable::encode(id, dst).map_err(Self::Error::Id) - } - - fn decode_id(bytes: &[u8]) -> Result<(usize, Self::Id), Self::Error> { - Transformable::decode(bytes).map_err(Self::Error::Id) - } - - fn address_encoded_len(address: &Self::Address) -> usize { - Transformable::encoded_len(address) - } - - fn encode_address(address: &Self::Address, dst: &mut [u8]) -> Result { - Transformable::encode(address, dst).map_err(Self::Error::Address) - } - - fn decode_address(bytes: &[u8]) -> Result<(usize, Self::Address), Self::Error> { - Transformable::decode(bytes).map_err(Self::Error::Address) - } - - fn coordinate_encoded_len(coordinate: &Coordinate) -> usize { - Transformable::encoded_len(coordinate) - } - - fn encode_coordinate(coordinate: &Coordinate, dst: &mut [u8]) -> Result { - Transformable::encode(coordinate, dst).map_err(Self::Error::Coordinate) - } - - fn decode_coordinate(bytes: &[u8]) -> Result<(usize, Coordinate), Self::Error> { - Transformable::decode(bytes).map_err(Self::Error::Coordinate) - } - - fn tags_encoded_len(tags: &Tags) -> usize { - Transformable::encoded_len(tags) - } - - fn encode_tags(tags: &Tags, dst: &mut [u8]) -> Result { - Transformable::encode(tags, dst).map_err(Self::Error::Tags) - } - - fn decode_tags(bytes: &[u8]) -> Result<(usize, Tags), Self::Error> { - Transformable::decode(bytes).map_err(Self::Error::Tags) - } - - fn message_encoded_len(msg: impl AsMessageRef) -> usize { - let msg = msg.as_message_ref(); - serf_proto::Encodable::encoded_len(&msg) - } - - fn encode_message( - msg: impl AsMessageRef, - mut dst: impl AsMut<[u8]>, - ) -> Result { - let 
msg = msg.as_message_ref(); - serf_proto::Encodable::encode(&msg, dst.as_mut()).map_err(Into::into) - } - - fn decode_message( - ty: MessageType, - bytes: impl AsRef<[u8]>, - ) -> Result<(usize, SerfMessage), Self::Error> { - match ty { - MessageType::Leave => LeaveMessage::decode(bytes.as_ref()) - .map(|(n, m)| (n, SerfMessage::Leave(m))) - .map_err(|e| Self::Error::Message(e.into())), - MessageType::Join => JoinMessage::decode(bytes.as_ref()) - .map(|(n, m)| (n, SerfMessage::Join(m))) - .map_err(|e| Self::Error::Message(e.into())), - MessageType::PushPull => PushPullMessage::decode(bytes.as_ref()) - .map(|(n, m)| (n, SerfMessage::PushPull(m))) - .map_err(|e| Self::Error::Message(e.into())), - MessageType::UserEvent => UserEventMessage::decode(bytes.as_ref()) - .map(|(n, m)| (n, SerfMessage::UserEvent(m))) - .map_err(|e| Self::Error::Message(e.into())), - MessageType::Query => QueryMessage::decode(bytes.as_ref()) - .map(|(n, m)| (n, SerfMessage::Query(m))) - .map_err(|e| Self::Error::Message(e.into())), - MessageType::QueryResponse => QueryResponseMessage::decode(bytes.as_ref()) - .map(|(n, m)| (n, SerfMessage::QueryResponse(m))) - .map_err(|e| Self::Error::Message(e.into())), - MessageType::ConflictResponse => Member::decode(bytes.as_ref()) - .map(|(n, m)| (n, SerfMessage::ConflictResponse(m))) - .map_err(|e| Self::Error::Message(e.into())), - MessageType::Relay => Err(Self::Error::UnexpectedRelayMessage), - #[cfg(feature = "encryption")] - MessageType::KeyRequest => serf_proto::KeyRequestMessage::decode(bytes.as_ref()) - .map(|(n, m)| (n, SerfMessage::KeyRequest(m))) - .map_err(|e| Self::Error::Message(e.into())), - #[cfg(feature = "encryption")] - MessageType::KeyResponse => serf_proto::KeyResponseMessage::decode(bytes.as_ref()) - .map(|(n, m)| (n, SerfMessage::KeyResponse(m))) - .map_err(|e| Self::Error::Message(e.into())), - _ => unreachable!(), - } - } -} diff --git a/serf-core/src/serf/base.rs b/serf-core/src/serf/base.rs index 3959819..56b90e6 100644 --- 
a/serf-core/src/serf/base.rs +++ b/serf-core/src/serf/base.rs @@ -1524,6 +1524,17 @@ where true } + status => { + tracing::warn!(status=%status, "serf: received leave intent for unknown member status"); + member.member.status = MemberStatus::Leaving; + + if msg.prune { + let owned = member.clone(); + drop(members_mut); + self.handle_prune(&owned, *members.borrow_mut()).await; + } + true + } } } diff --git a/serf-proto/src/key.rs b/serf-proto/src/key.rs index 836ecce..211f2a4 100644 --- a/serf-proto/src/key.rs +++ b/serf-proto/src/key.rs @@ -1,5 +1,8 @@ use indexmap::IndexMap; -use memberlist_proto::{SecretKey, SecretKeys}; +use memberlist_proto::{ + Data, DataRef, DecodeError, EncodeError, RepeatedDecoder, SecretKey, SecretKeys, WireType, + utils::{merge, skip, split}, +}; use smol_str::SmolStr; /// KeyRequest is used to contain input parameters which get broadcasted to all @@ -19,6 +22,82 @@ pub struct KeyRequestMessage { key: Option, } +const KEY_REQ_KEY_TAG: u8 = 1; +const KEY_REQ_KEY_BYTE: u8 = merge(WireType::LengthDelimited, KEY_REQ_KEY_TAG); + +impl DataRef<'_, Self> for KeyRequestMessage { + fn decode(buf: &'_ [u8]) -> Result<(usize, Self), DecodeError> + where + Self: Sized, + { + let mut offset = 0; + let buf_len = buf.len(); + + let mut key = None; + + while offset < buf_len { + match buf[offset] { + KEY_REQ_KEY_BYTE => { + offset += 1; + + let (bytes_read, val) = ::decode_length_delimited(&buf[offset..])?; + offset += bytes_read; + key = Some(val); + } + other => { + offset += 1; + + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &buf[offset..])?; + } + } + } + + Ok((offset, Self { key })) + } +} + +impl Data for KeyRequestMessage { + type Ref<'a> = Self; + + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + Ok(val) + } + + fn encoded_len(&self) -> usize { + let mut len = 0; + if let Some(key) = &self.key { + len += 1 + 
key.encoded_len_with_length_delimited(); + } + len + } + + fn encode(&self, buf: &mut [u8]) -> Result { + let buf_len = buf.len(); + let mut offset = 0; + + if let Some(key) = &self.key { + if buf_len < offset + 1 { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + buf[offset] = KEY_REQ_KEY_BYTE; + offset += 1; + + let bytes_written = key.encode_length_delimited(&mut buf[offset..])?; + offset += bytes_written; + } + + Ok(offset) + } +} + /// Key response message #[viewit::viewit(setters(prefix = "with"))] #[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] @@ -68,6 +147,226 @@ impl KeyResponseMessage { } } +const KEY_RESPONSE_RESULT_TAG: u8 = 1; +const KEY_RESPONSE_RESULT_BYTE: u8 = merge(WireType::Byte, KEY_RESPONSE_RESULT_TAG); +const KEY_RESPONSE_MESSAGE_TAG: u8 = 2; +const KEY_RESPONSE_MESSAGE_BYTE: u8 = merge(WireType::LengthDelimited, KEY_RESPONSE_MESSAGE_TAG); +const KEY_RESPONSE_KEYS_TAG: u8 = 3; +const KEY_RESPONSE_KEYS_BYTE: u8 = merge(WireType::LengthDelimited, KEY_RESPONSE_KEYS_TAG); +const KEY_RESPONSE_PRIMARY_KEY_TAG: u8 = 4; +const KEY_RESPONSE_PRIMARY_KEY_BYTE: u8 = + merge(WireType::LengthDelimited, KEY_RESPONSE_PRIMARY_KEY_TAG); + +/// The reference type for [`KeyResponseMessage`]. 
+#[viewit::viewit(getters(style = "ref", vis_all = "pub"), setters(skip), vis_all = "")] +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub struct KeyResponseMessageRef<'a> { + #[viewit(getter(const, attrs(doc = "Returns true/false if there were errors or not")))] + result: bool, + #[viewit(getter(const, attrs(doc = "Returns the error messages or other information")))] + message: &'a str, + #[viewit(getter(const, attrs(doc = "Returns a list of installed keys")))] + keys: RepeatedDecoder<'a>, + #[viewit(getter(const, attrs(doc = "Returns the primary key")))] + primary_key: Option, +} + +impl<'a> DataRef<'a, KeyResponseMessage> for KeyResponseMessageRef<'a> { + fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> + where + Self: Sized, + { + let mut offset = 0; + let buf_len = buf.len(); + + let mut result = None; + let mut message = None; + let mut keys_offsets = None; + let mut num_keys = 0; + let mut primary_key = None; + + while offset < buf_len { + match buf[offset] { + KEY_RESPONSE_RESULT_BYTE => { + if result.is_some() { + return Err(DecodeError::duplicate_field( + "KeyResponseMessage", + "result", + KEY_RESPONSE_RESULT_TAG, + )); + } + offset += 1; + if offset >= buf_len { + return Err(DecodeError::buffer_underflow()); + } + result = Some(buf[offset] != 0); + offset += 1; + } + KEY_RESPONSE_MESSAGE_BYTE => { + if message.is_some() { + return Err(DecodeError::duplicate_field( + "KeyResponseMessage", + "message", + KEY_RESPONSE_MESSAGE_TAG, + )); + } + + offset += 1; + let (bytes_read, val) = + <&str as DataRef<'_, SmolStr>>::decode_length_delimited(&buf[offset..])?; + offset += bytes_read; + message = Some(val); + } + KEY_RESPONSE_KEYS_BYTE => { + offset += 1; + let readed = skip(WireType::LengthDelimited, &buf[offset..])?; + if let Some((ref mut fnso, ref mut lnso)) = keys_offsets { + if *fnso > offset { + *fnso = offset - 1; + } + + if *lnso < offset + readed { + *lnso = offset + readed; + } + } else { + keys_offsets = Some((offset - 1, offset + 
readed)); + } + num_keys += 1; + offset += readed; + } + KEY_RESPONSE_PRIMARY_KEY_BYTE => { + if primary_key.is_some() { + return Err(DecodeError::duplicate_field( + "KeyResponseMessage", + "primary_key", + KEY_RESPONSE_PRIMARY_KEY_TAG, + )); + } + + offset += 1; + let (bytes_read, val) = ::decode_length_delimited(&buf[offset..])?; + offset += bytes_read; + primary_key = Some(val); + } + other => { + offset += 1; + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &buf[offset..])?; + } + } + } + + Ok(( + offset, + Self { + result: result.unwrap_or_default(), + message: message.unwrap_or_default(), + keys: if let Some((start, end)) = keys_offsets { + RepeatedDecoder::new(KEY_RESPONSE_KEYS_TAG, WireType::LengthDelimited, buf) + .with_nums(num_keys) + .with_offsets(start, end) + } else { + RepeatedDecoder::new(KEY_RESPONSE_KEYS_TAG, WireType::LengthDelimited, buf) + }, + primary_key, + }, + )) + } +} + +impl Data for KeyResponseMessage { + type Ref<'a> = KeyResponseMessageRef<'a>; + + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + val + .keys + .iter::() + .map(|res| res.and_then(Data::from_ref)) + .collect::>() + .map(|keys| Self { + result: val.result, + message: SmolStr::new(val.message), + keys, + primary_key: val.primary_key, + }) + } + + fn encoded_len(&self) -> usize { + let mut len = 0; + if self.result { + len += 1 + 1; + } + + if !self.message.is_empty() { + len += 1 + self.message.encoded_len_with_length_delimited(); + } + + len += self + .keys + .iter() + .map(|key| 1 + key.encoded_len_with_length_delimited()) + .sum::(); + + if let Some(key) = &self.primary_key { + len += 1 + key.encoded_len_with_length_delimited(); + } + + len + } + + fn encode(&self, buf: &mut [u8]) -> Result { + macro_rules! 
bail { + ($this:ident($offset:expr, $len:ident)) => { + if $offset >= $len { + return Err(EncodeError::insufficient_buffer(self.encoded_len(), $len)); + } + }; + } + + let buf_len = buf.len(); + let mut offset = 0; + + if self.result { + bail!(self(offset, buf_len)); + buf[offset] = KEY_RESPONSE_RESULT_BYTE; + offset += 1; + bail!(self(offset, buf_len)); + buf[offset] = 1; + offset += 1; + } + + if !self.message.is_empty() { + bail!(self(offset, buf_len)); + buf[offset] = KEY_RESPONSE_MESSAGE_BYTE; + offset += 1; + offset += self.message.encode_length_delimited(&mut buf[offset..])?; + } + + for key in self.keys.iter() { + bail!(self(offset, buf_len)); + buf[offset] = KEY_RESPONSE_KEYS_BYTE; + offset += 1; + offset += key.encode_length_delimited(&mut buf[offset..])?; + } + + if let Some(key) = &self.primary_key { + bail!(self(offset, buf_len)); + buf[offset] = KEY_RESPONSE_PRIMARY_KEY_BYTE; + offset += 1; + offset += key.encode_length_delimited(&mut buf[offset..])?; + } + + #[cfg(debug_assertions)] + super::debug_assert_write_eq(offset, self.encoded_len()); + + Ok(offset) + } +} + /// KeyResponse is used to relay a query for a list of all keys in use. #[viewit::viewit(setters(prefix = "with"))] #[derive(Default)] @@ -124,7 +423,7 @@ pub struct KeyResponse { doc = "Sets a mapping of the value of the key bytes to the number of nodes that have the key installed (Builder pattern)" )) )] - keys: IndexMap, + keys: IndexMap, /// A mapping of the value of the primary /// key bytes to the number of nodes that have the key installed. @@ -140,7 +439,7 @@ pub struct KeyResponse { doc = "Sets a mapping of the value of the primary key bytes to the number of nodes that have the key installed. 
(Builder pattern)" )) )] - primary_keys: IndexMap, + primary_keys: IndexMap, } /// KeyRequestOptions is used to contain optional parameters for a keyring operation diff --git a/serf-proto/src/lib.rs b/serf-proto/src/lib.rs index 8321c73..8333a6d 100644 --- a/serf-proto/src/lib.rs +++ b/serf-proto/src/lib.rs @@ -26,8 +26,8 @@ pub use leave::*; mod member; pub use member::*; -mod message; -pub use message::*; +// mod message; +// pub use message::*; mod join; pub use join::*; diff --git a/serf-proto/src/message.rs b/serf-proto/src/message.rs index f126dc0..b0bddec 100644 --- a/serf-proto/src/message.rs +++ b/serf-proto/src/message.rs @@ -1,370 +1,370 @@ -use std::sync::Arc; - -use super::{ - JoinMessage, LeaveMessage, Member, PushPullMessage, PushPullMessageRef, QueryMessage, - QueryResponseMessage, UserEventMessage, -}; - -#[cfg(feature = "encryption")] -use super::{KeyRequestMessage, KeyResponseMessage}; - -const LEAVE_MESSAGE_TAG: u8 = 0; -const JOIN_MESSAGE_TAG: u8 = 1; -const PUSH_PULL_MESSAGE_TAG: u8 = 2; -const USER_EVENT_MESSAGE_TAG: u8 = 3; -const QUERY_MESSAGE_TAG: u8 = 4; -const QUERY_RESPONSE_MESSAGE_TAG: u8 = 5; -const CONFLICT_RESPONSE_MESSAGE_TAG: u8 = 6; -const RELAY_MESSAGE_TAG: u8 = 7; -#[cfg(feature = "encryption")] -const KEY_REQUEST_MESSAGE_TAG: u8 = 253; -#[cfg(feature = "encryption")] -const KEY_RESPONSE_MESSAGE_TAG: u8 = 254; - -/// Unknown message type error -#[derive(Debug, thiserror::Error)] -#[error("unknown message type byte: {0}")] -pub struct UnknownMessageType(u8); - -impl TryFrom for MessageType { - type Error = UnknownMessageType; - - fn try_from(value: u8) -> Result { - Ok(match value { - LEAVE_MESSAGE_TAG => Self::Leave, - JOIN_MESSAGE_TAG => Self::Join, - PUSH_PULL_MESSAGE_TAG => Self::PushPull, - USER_EVENT_MESSAGE_TAG => Self::UserEvent, - QUERY_MESSAGE_TAG => Self::Query, - QUERY_RESPONSE_MESSAGE_TAG => Self::QueryResponse, - CONFLICT_RESPONSE_MESSAGE_TAG => Self::ConflictResponse, - RELAY_MESSAGE_TAG => Self::Relay, - 
#[cfg(feature = "encryption")] - KEY_REQUEST_MESSAGE_TAG => Self::KeyRequest, - #[cfg(feature = "encryption")] - KEY_RESPONSE_MESSAGE_TAG => Self::KeyResponse, - _ => return Err(UnknownMessageType(value)), - }) - } -} - -impl From for u8 { - fn from(value: MessageType) -> Self { - value as u8 - } -} - -/// The types of gossip messages Serf will send along -/// memberlist. -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -#[repr(u8)] -#[non_exhaustive] -pub enum MessageType { - /// Leave message - Leave = LEAVE_MESSAGE_TAG, - /// Join message - Join = JOIN_MESSAGE_TAG, - /// PushPull message - PushPull = PUSH_PULL_MESSAGE_TAG, - /// UserEvent message - UserEvent = USER_EVENT_MESSAGE_TAG, - /// Query message - Query = QUERY_MESSAGE_TAG, - /// QueryResponse message - QueryResponse = QUERY_RESPONSE_MESSAGE_TAG, - /// ConflictResponse message - ConflictResponse = CONFLICT_RESPONSE_MESSAGE_TAG, - /// Relay message - Relay = RELAY_MESSAGE_TAG, - /// KeyRequest message - #[cfg(feature = "encryption")] - KeyRequest = KEY_REQUEST_MESSAGE_TAG, - /// KeyResponse message - #[cfg(feature = "encryption")] - KeyResponse = KEY_RESPONSE_MESSAGE_TAG, -} - -impl MessageType { - /// Get the string representation of the message type - #[inline] - pub const fn as_str(&self) -> &'static str { - match self { - Self::Leave => "leave", - Self::Join => "join", - Self::PushPull => "push pull", - Self::UserEvent => "user event", - Self::Query => "query", - Self::QueryResponse => "query response", - Self::ConflictResponse => "conflict response", - Self::Relay => "relay", - #[cfg(feature = "encryption")] - Self::KeyRequest => "key request", - #[cfg(feature = "encryption")] - Self::KeyResponse => "key response", - } - } -} - -/// Used to do a cheap reference to message reference conversion. -pub trait AsMessageRef { - /// Converts this type into a shared reference of the (usually inferred) input type. 
- fn as_message_ref(&self) -> SerfMessageRef<'_, I, A>; -} - -/// The reference type of [`SerfMessage`]. -#[derive(Debug)] -pub enum SerfMessageRef<'a, I, A> { - /// Leave message reference - Leave(&'a LeaveMessage), - /// Join message reference - Join(&'a JoinMessage), - /// PushPull message reference - PushPull(PushPullMessageRef<'a, I>), - /// UserEvent message reference - UserEvent(&'a UserEventMessage), - /// Query message reference - Query(&'a QueryMessage), - /// QueryResponse message reference - QueryResponse(&'a QueryResponseMessage), - /// ConflictResponse message reference - ConflictResponse(&'a Member), - /// KeyRequest message reference - #[cfg(feature = "encryption")] - KeyRequest(&'a KeyRequestMessage), - /// KeyResponse message reference - #[cfg(feature = "encryption")] - KeyResponse(&'a KeyResponseMessage), -} - -impl Clone for SerfMessageRef<'_, I, A> { - fn clone(&self) -> Self { - *self - } -} - -impl Copy for SerfMessageRef<'_, I, A> {} - -impl AsMessageRef for SerfMessageRef<'_, I, A> { - fn as_message_ref(&self) -> SerfMessageRef { - *self - } -} - -/// The types of gossip messages Serf will send along -/// memberlist. 
-#[derive(Debug, Clone)] -pub enum SerfMessage { - /// Leave message - Leave(LeaveMessage), - /// Join message - Join(JoinMessage), - /// PushPull message - PushPull(PushPullMessage), - /// UserEvent message - UserEvent(UserEventMessage), - /// Query message - Query(QueryMessage), - /// QueryResponse message - QueryResponse(QueryResponseMessage), - /// ConflictResponse message - ConflictResponse(Member), - /// Relay message - #[cfg(feature = "encryption")] - KeyRequest(KeyRequestMessage), - /// KeyResponse message - #[cfg(feature = "encryption")] - KeyResponse(KeyResponseMessage), -} - -impl<'a, I, A> From<&'a SerfMessage> for MessageType { - fn from(msg: &'a SerfMessage) -> Self { - match msg { - SerfMessage::Leave(_) => MessageType::Leave, - SerfMessage::Join(_) => MessageType::Join, - SerfMessage::PushPull(_) => MessageType::PushPull, - SerfMessage::UserEvent(_) => MessageType::UserEvent, - SerfMessage::Query(_) => MessageType::Query, - SerfMessage::QueryResponse(_) => MessageType::QueryResponse, - SerfMessage::ConflictResponse(_) => MessageType::ConflictResponse, - #[cfg(feature = "encryption")] - SerfMessage::KeyRequest(_) => MessageType::KeyRequest, - #[cfg(feature = "encryption")] - SerfMessage::KeyResponse(_) => MessageType::KeyResponse, - } - } -} - -impl AsMessageRef for QueryMessage { - fn as_message_ref(&self) -> SerfMessageRef { - SerfMessageRef::Query(self) - } -} - -impl AsMessageRef for QueryResponseMessage { - fn as_message_ref(&self) -> SerfMessageRef { - SerfMessageRef::QueryResponse(self) - } -} - -impl AsMessageRef for JoinMessage { - fn as_message_ref(&self) -> SerfMessageRef { - SerfMessageRef::Join(self) - } -} - -impl AsMessageRef for UserEventMessage { - fn as_message_ref(&self) -> SerfMessageRef { - SerfMessageRef::UserEvent(self) - } -} - -impl AsMessageRef for &QueryMessage { - fn as_message_ref(&self) -> SerfMessageRef { - SerfMessageRef::Query(self) - } -} - -impl AsMessageRef for &QueryResponseMessage { - fn as_message_ref(&self) -> 
SerfMessageRef { - SerfMessageRef::QueryResponse(self) - } -} - -impl AsMessageRef for &JoinMessage { - fn as_message_ref(&self) -> SerfMessageRef { - SerfMessageRef::Join(self) - } -} - -impl AsMessageRef for PushPullMessageRef<'_, I> { - fn as_message_ref(&self) -> SerfMessageRef { - SerfMessageRef::PushPull(*self) - } -} - -impl AsMessageRef for &PushPullMessage { - fn as_message_ref(&self) -> SerfMessageRef { - SerfMessageRef::PushPull(PushPullMessageRef { - ltime: self.ltime, - status_ltimes: &self.status_ltimes, - left_members: &self.left_members, - event_ltime: self.event_ltime, - events: &self.events, - query_ltime: self.query_ltime, - }) - } -} - -impl AsMessageRef for &UserEventMessage { - fn as_message_ref(&self) -> SerfMessageRef { - SerfMessageRef::UserEvent(self) - } -} - -impl AsMessageRef for &LeaveMessage { - fn as_message_ref(&self) -> SerfMessageRef { - SerfMessageRef::Leave(self) - } -} - -impl AsMessageRef for &Member { - fn as_message_ref(&self) -> SerfMessageRef { - SerfMessageRef::ConflictResponse(self) - } -} - -impl AsMessageRef for &Arc> { - fn as_message_ref(&self) -> SerfMessageRef { - SerfMessageRef::ConflictResponse(self) - } -} - -#[cfg(feature = "encryption")] -impl AsMessageRef for &KeyRequestMessage { - fn as_message_ref(&self) -> SerfMessageRef { - SerfMessageRef::KeyRequest(self) - } -} - -#[cfg(feature = "encryption")] -impl AsMessageRef for &KeyResponseMessage { - fn as_message_ref(&self) -> SerfMessageRef { - SerfMessageRef::KeyResponse(self) - } -} - -impl AsMessageRef for SerfMessage { - fn as_message_ref(&self) -> SerfMessageRef { - match self { - Self::Leave(l) => SerfMessageRef::Leave(l), - Self::Join(j) => SerfMessageRef::Join(j), - Self::PushPull(pp) => SerfMessageRef::PushPull(PushPullMessageRef { - ltime: pp.ltime, - status_ltimes: &pp.status_ltimes, - left_members: &pp.left_members, - event_ltime: pp.event_ltime, - events: &pp.events, - query_ltime: pp.query_ltime, - }), - Self::UserEvent(u) => 
SerfMessageRef::UserEvent(u), - Self::Query(q) => SerfMessageRef::Query(q), - Self::QueryResponse(q) => SerfMessageRef::QueryResponse(q), - Self::ConflictResponse(m) => SerfMessageRef::ConflictResponse(m), - #[cfg(feature = "encryption")] - Self::KeyRequest(kr) => SerfMessageRef::KeyRequest(kr), - #[cfg(feature = "encryption")] - Self::KeyResponse(kr) => SerfMessageRef::KeyResponse(kr), - } - } -} - -impl AsMessageRef for &SerfMessage { - fn as_message_ref(&self) -> SerfMessageRef { - match self { - SerfMessage::Leave(l) => SerfMessageRef::Leave(l), - SerfMessage::Join(j) => SerfMessageRef::Join(j), - SerfMessage::PushPull(pp) => SerfMessageRef::PushPull(PushPullMessageRef { - ltime: pp.ltime, - status_ltimes: &pp.status_ltimes, - left_members: &pp.left_members, - event_ltime: pp.event_ltime, - events: &pp.events, - query_ltime: pp.query_ltime, - }), - SerfMessage::UserEvent(u) => SerfMessageRef::UserEvent(u), - SerfMessage::Query(q) => SerfMessageRef::Query(q), - SerfMessage::QueryResponse(q) => SerfMessageRef::QueryResponse(q), - SerfMessage::ConflictResponse(m) => SerfMessageRef::ConflictResponse(m), - #[cfg(feature = "encryption")] - SerfMessage::KeyRequest(kr) => SerfMessageRef::KeyRequest(kr), - #[cfg(feature = "encryption")] - SerfMessage::KeyResponse(kr) => SerfMessageRef::KeyResponse(kr), - } - } -} - -impl core::fmt::Display for SerfMessage { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "{}", self.ty().as_str()) - } -} - -impl SerfMessage { - /// Returns the message type of this message - #[inline] - pub const fn ty(&self) -> MessageType { - match self { - Self::Leave(_) => MessageType::Leave, - Self::Join(_) => MessageType::Join, - Self::PushPull(_) => MessageType::PushPull, - Self::UserEvent(_) => MessageType::UserEvent, - Self::Query(_) => MessageType::Query, - Self::QueryResponse(_) => MessageType::QueryResponse, - Self::ConflictResponse(_) => MessageType::ConflictResponse, - #[cfg(feature = "encryption")] - 
Self::KeyRequest(_) => MessageType::KeyRequest, - #[cfg(feature = "encryption")] - Self::KeyResponse(_) => MessageType::KeyResponse, - } - } -} +// use std::sync::Arc; + +// use super::{ +// JoinMessage, LeaveMessage, Member, PushPullMessage, PushPullMessageRef, QueryMessage, +// QueryResponseMessage, UserEventMessage, +// }; + +// #[cfg(feature = "encryption")] +// use super::{KeyRequestMessage, KeyResponseMessage}; + +// const LEAVE_MESSAGE_TAG: u8 = 0; +// const JOIN_MESSAGE_TAG: u8 = 1; +// const PUSH_PULL_MESSAGE_TAG: u8 = 2; +// const USER_EVENT_MESSAGE_TAG: u8 = 3; +// const QUERY_MESSAGE_TAG: u8 = 4; +// const QUERY_RESPONSE_MESSAGE_TAG: u8 = 5; +// const CONFLICT_RESPONSE_MESSAGE_TAG: u8 = 6; +// const RELAY_MESSAGE_TAG: u8 = 7; +// #[cfg(feature = "encryption")] +// const KEY_REQUEST_MESSAGE_TAG: u8 = 253; +// #[cfg(feature = "encryption")] +// const KEY_RESPONSE_MESSAGE_TAG: u8 = 254; + +// /// Unknown message type error +// #[derive(Debug, thiserror::Error)] +// #[error("unknown message type byte: {0}")] +// pub struct UnknownMessageType(u8); + +// impl TryFrom for MessageType { +// type Error = UnknownMessageType; + +// fn try_from(value: u8) -> Result { +// Ok(match value { +// LEAVE_MESSAGE_TAG => Self::Leave, +// JOIN_MESSAGE_TAG => Self::Join, +// PUSH_PULL_MESSAGE_TAG => Self::PushPull, +// USER_EVENT_MESSAGE_TAG => Self::UserEvent, +// QUERY_MESSAGE_TAG => Self::Query, +// QUERY_RESPONSE_MESSAGE_TAG => Self::QueryResponse, +// CONFLICT_RESPONSE_MESSAGE_TAG => Self::ConflictResponse, +// RELAY_MESSAGE_TAG => Self::Relay, +// #[cfg(feature = "encryption")] +// KEY_REQUEST_MESSAGE_TAG => Self::KeyRequest, +// #[cfg(feature = "encryption")] +// KEY_RESPONSE_MESSAGE_TAG => Self::KeyResponse, +// _ => return Err(UnknownMessageType(value)), +// }) +// } +// } + +// impl From for u8 { +// fn from(value: MessageType) -> Self { +// value as u8 +// } +// } + +// /// The types of gossip messages Serf will send along +// /// memberlist. 
+// #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +// #[repr(u8)] +// #[non_exhaustive] +// pub enum MessageType { +// /// Leave message +// Leave = LEAVE_MESSAGE_TAG, +// /// Join message +// Join = JOIN_MESSAGE_TAG, +// /// PushPull message +// PushPull = PUSH_PULL_MESSAGE_TAG, +// /// UserEvent message +// UserEvent = USER_EVENT_MESSAGE_TAG, +// /// Query message +// Query = QUERY_MESSAGE_TAG, +// /// QueryResponse message +// QueryResponse = QUERY_RESPONSE_MESSAGE_TAG, +// /// ConflictResponse message +// ConflictResponse = CONFLICT_RESPONSE_MESSAGE_TAG, +// /// Relay message +// Relay = RELAY_MESSAGE_TAG, +// /// KeyRequest message +// #[cfg(feature = "encryption")] +// KeyRequest = KEY_REQUEST_MESSAGE_TAG, +// /// KeyResponse message +// #[cfg(feature = "encryption")] +// KeyResponse = KEY_RESPONSE_MESSAGE_TAG, +// } + +// impl MessageType { +// /// Get the string representation of the message type +// #[inline] +// pub const fn as_str(&self) -> &'static str { +// match self { +// Self::Leave => "leave", +// Self::Join => "join", +// Self::PushPull => "push pull", +// Self::UserEvent => "user event", +// Self::Query => "query", +// Self::QueryResponse => "query response", +// Self::ConflictResponse => "conflict response", +// Self::Relay => "relay", +// #[cfg(feature = "encryption")] +// Self::KeyRequest => "key request", +// #[cfg(feature = "encryption")] +// Self::KeyResponse => "key response", +// } +// } +// } + +// /// Used to do a cheap reference to message reference conversion. +// pub trait AsMessageRef { +// /// Converts this type into a shared reference of the (usually inferred) input type. +// fn as_message_ref(&self) -> SerfMessageRef<'_, I, A>; +// } + +// /// The reference type of [`SerfMessage`]. 
+// #[derive(Debug)] +// pub enum SerfMessageRef<'a, I, A> { +// /// Leave message reference +// Leave(&'a LeaveMessage), +// /// Join message reference +// Join(&'a JoinMessage), +// /// PushPull message reference +// PushPull(PushPullMessageRef<'a, I>), +// /// UserEvent message reference +// UserEvent(&'a UserEventMessage), +// /// Query message reference +// Query(&'a QueryMessage), +// /// QueryResponse message reference +// QueryResponse(&'a QueryResponseMessage), +// /// ConflictResponse message reference +// ConflictResponse(&'a Member), +// /// KeyRequest message reference +// #[cfg(feature = "encryption")] +// KeyRequest(&'a KeyRequestMessage), +// /// KeyResponse message reference +// #[cfg(feature = "encryption")] +// KeyResponse(&'a KeyResponseMessage), +// } + +// impl Clone for SerfMessageRef<'_, I, A> { +// fn clone(&self) -> Self { +// *self +// } +// } + +// impl Copy for SerfMessageRef<'_, I, A> {} + +// impl AsMessageRef for SerfMessageRef<'_, I, A> { +// fn as_message_ref(&self) -> SerfMessageRef { +// *self +// } +// } + +// /// The types of gossip messages Serf will send along +// /// memberlist. 
+// #[derive(Debug, Clone)] +// pub enum SerfMessage { +// /// Leave message +// Leave(LeaveMessage), +// /// Join message +// Join(JoinMessage), +// /// PushPull message +// PushPull(PushPullMessage), +// /// UserEvent message +// UserEvent(UserEventMessage), +// /// Query message +// Query(QueryMessage), +// /// QueryResponse message +// QueryResponse(QueryResponseMessage), +// /// ConflictResponse message +// ConflictResponse(Member), +// /// Relay message +// #[cfg(feature = "encryption")] +// KeyRequest(KeyRequestMessage), +// /// KeyResponse message +// #[cfg(feature = "encryption")] +// KeyResponse(KeyResponseMessage), +// } + +// impl<'a, I, A> From<&'a SerfMessage> for MessageType { +// fn from(msg: &'a SerfMessage) -> Self { +// match msg { +// SerfMessage::Leave(_) => MessageType::Leave, +// SerfMessage::Join(_) => MessageType::Join, +// SerfMessage::PushPull(_) => MessageType::PushPull, +// SerfMessage::UserEvent(_) => MessageType::UserEvent, +// SerfMessage::Query(_) => MessageType::Query, +// SerfMessage::QueryResponse(_) => MessageType::QueryResponse, +// SerfMessage::ConflictResponse(_) => MessageType::ConflictResponse, +// #[cfg(feature = "encryption")] +// SerfMessage::KeyRequest(_) => MessageType::KeyRequest, +// #[cfg(feature = "encryption")] +// SerfMessage::KeyResponse(_) => MessageType::KeyResponse, +// } +// } +// } + +// impl AsMessageRef for QueryMessage { +// fn as_message_ref(&self) -> SerfMessageRef { +// SerfMessageRef::Query(self) +// } +// } + +// impl AsMessageRef for QueryResponseMessage { +// fn as_message_ref(&self) -> SerfMessageRef { +// SerfMessageRef::QueryResponse(self) +// } +// } + +// impl AsMessageRef for JoinMessage { +// fn as_message_ref(&self) -> SerfMessageRef { +// SerfMessageRef::Join(self) +// } +// } + +// impl AsMessageRef for UserEventMessage { +// fn as_message_ref(&self) -> SerfMessageRef { +// SerfMessageRef::UserEvent(self) +// } +// } + +// impl AsMessageRef for &QueryMessage { +// fn 
as_message_ref(&self) -> SerfMessageRef { +// SerfMessageRef::Query(self) +// } +// } + +// impl AsMessageRef for &QueryResponseMessage { +// fn as_message_ref(&self) -> SerfMessageRef { +// SerfMessageRef::QueryResponse(self) +// } +// } + +// impl AsMessageRef for &JoinMessage { +// fn as_message_ref(&self) -> SerfMessageRef { +// SerfMessageRef::Join(self) +// } +// } + +// impl AsMessageRef for PushPullMessageRef<'_, I> { +// fn as_message_ref(&self) -> SerfMessageRef { +// SerfMessageRef::PushPull(*self) +// } +// } + +// impl AsMessageRef for &PushPullMessage { +// fn as_message_ref(&self) -> SerfMessageRef { +// SerfMessageRef::PushPull(PushPullMessageRef { +// ltime: self.ltime, +// status_ltimes: &self.status_ltimes, +// left_members: &self.left_members, +// event_ltime: self.event_ltime, +// events: &self.events, +// query_ltime: self.query_ltime, +// }) +// } +// } + +// impl AsMessageRef for &UserEventMessage { +// fn as_message_ref(&self) -> SerfMessageRef { +// SerfMessageRef::UserEvent(self) +// } +// } + +// impl AsMessageRef for &LeaveMessage { +// fn as_message_ref(&self) -> SerfMessageRef { +// SerfMessageRef::Leave(self) +// } +// } + +// impl AsMessageRef for &Member { +// fn as_message_ref(&self) -> SerfMessageRef { +// SerfMessageRef::ConflictResponse(self) +// } +// } + +// impl AsMessageRef for &Arc> { +// fn as_message_ref(&self) -> SerfMessageRef { +// SerfMessageRef::ConflictResponse(self) +// } +// } + +// #[cfg(feature = "encryption")] +// impl AsMessageRef for &KeyRequestMessage { +// fn as_message_ref(&self) -> SerfMessageRef { +// SerfMessageRef::KeyRequest(self) +// } +// } + +// #[cfg(feature = "encryption")] +// impl AsMessageRef for &KeyResponseMessage { +// fn as_message_ref(&self) -> SerfMessageRef { +// SerfMessageRef::KeyResponse(self) +// } +// } + +// impl AsMessageRef for SerfMessage { +// fn as_message_ref(&self) -> SerfMessageRef { +// match self { +// Self::Leave(l) => SerfMessageRef::Leave(l), +// Self::Join(j) => 
SerfMessageRef::Join(j), +// Self::PushPull(pp) => SerfMessageRef::PushPull(PushPullMessageRef { +// ltime: pp.ltime, +// status_ltimes: &pp.status_ltimes, +// left_members: &pp.left_members, +// event_ltime: pp.event_ltime, +// events: &pp.events, +// query_ltime: pp.query_ltime, +// }), +// Self::UserEvent(u) => SerfMessageRef::UserEvent(u), +// Self::Query(q) => SerfMessageRef::Query(q), +// Self::QueryResponse(q) => SerfMessageRef::QueryResponse(q), +// Self::ConflictResponse(m) => SerfMessageRef::ConflictResponse(m), +// #[cfg(feature = "encryption")] +// Self::KeyRequest(kr) => SerfMessageRef::KeyRequest(kr), +// #[cfg(feature = "encryption")] +// Self::KeyResponse(kr) => SerfMessageRef::KeyResponse(kr), +// } +// } +// } + +// impl AsMessageRef for &SerfMessage { +// fn as_message_ref(&self) -> SerfMessageRef { +// match self { +// SerfMessage::Leave(l) => SerfMessageRef::Leave(l), +// SerfMessage::Join(j) => SerfMessageRef::Join(j), +// SerfMessage::PushPull(pp) => SerfMessageRef::PushPull(PushPullMessageRef { +// ltime: pp.ltime, +// status_ltimes: &pp.status_ltimes, +// left_members: &pp.left_members, +// event_ltime: pp.event_ltime, +// events: &pp.events, +// query_ltime: pp.query_ltime, +// }), +// SerfMessage::UserEvent(u) => SerfMessageRef::UserEvent(u), +// SerfMessage::Query(q) => SerfMessageRef::Query(q), +// SerfMessage::QueryResponse(q) => SerfMessageRef::QueryResponse(q), +// SerfMessage::ConflictResponse(m) => SerfMessageRef::ConflictResponse(m), +// #[cfg(feature = "encryption")] +// SerfMessage::KeyRequest(kr) => SerfMessageRef::KeyRequest(kr), +// #[cfg(feature = "encryption")] +// SerfMessage::KeyResponse(kr) => SerfMessageRef::KeyResponse(kr), +// } +// } +// } + +// impl core::fmt::Display for SerfMessage { +// fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { +// write!(f, "{}", self.ty().as_str()) +// } +// } + +// impl SerfMessage { +// /// Returns the message type of this message +// #[inline] +// pub const fn 
ty(&self) -> MessageType { +// match self { +// Self::Leave(_) => MessageType::Leave, +// Self::Join(_) => MessageType::Join, +// Self::PushPull(_) => MessageType::PushPull, +// Self::UserEvent(_) => MessageType::UserEvent, +// Self::Query(_) => MessageType::Query, +// Self::QueryResponse(_) => MessageType::QueryResponse, +// Self::ConflictResponse(_) => MessageType::ConflictResponse, +// #[cfg(feature = "encryption")] +// Self::KeyRequest(_) => MessageType::KeyRequest, +// #[cfg(feature = "encryption")] +// Self::KeyResponse(_) => MessageType::KeyResponse, +// } +// } +// } diff --git a/serf-proto/src/push_pull.rs b/serf-proto/src/push_pull.rs index 76823c9..d676ebd 100644 --- a/serf-proto/src/push_pull.rs +++ b/serf-proto/src/push_pull.rs @@ -1,5 +1,8 @@ use indexmap::{IndexMap, IndexSet}; -use memberlist_proto::TinyVec; +use memberlist_proto::{ + Data, DataRef, DecodeError, EncodeError, RepeatedDecoder, TinyVec, TupleEncoder, WireType, + utils::{merge, skip, split}, +}; use super::{LamportTime, UserEvents}; @@ -63,8 +66,8 @@ pub struct PushPullMessage { getter(const, style = "ref", attrs(doc = "Returns the recent events")), setter(attrs(doc = "Sets the recent events (Builder pattern)")) )] - #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::>, TinyVec>>))] - events: TinyVec>, + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, TinyVec>))] + events: TinyVec, /// Lamport time for query clock #[viewit( getter( @@ -94,24 +97,57 @@ where } } +const LTIME_TAG: u8 = 1; +const STATUS_LTIMES_TAG: u8 = 2; +const LEFT_MEMBERS_TAG: u8 = 3; +const EVENT_LTIME_TAG: u8 = 4; +const EVENTS_TAG: u8 = 5; +const QUERY_LTIME_TAG: u8 = 6; + +const LTIME_BYTE: u8 = merge(WireType::Varint, LTIME_TAG); +const STATUS_LTIMES_BYTE: u8 = merge(WireType::LengthDelimited, STATUS_LTIMES_TAG); +const LEFT_MEMBERS_BYTE: u8 = merge(WireType::LengthDelimited, LEFT_MEMBERS_TAG); +const EVENT_LTIME_BYTE: u8 = merge(WireType::Varint, 
EVENT_LTIME_TAG); +const EVENTS_BYTE: u8 = merge(WireType::LengthDelimited, EVENTS_TAG); +const QUERY_LTIME_BYTE: u8 = merge(WireType::Varint, QUERY_LTIME_TAG); + /// Used when doing a state exchange. This /// is a relatively large message, but is sent infrequently -#[viewit::viewit(getters(skip), setters(skip))] +#[viewit::viewit(vis_all = "", getters(vis_all = "pub"), setters(skip))] #[derive(Debug)] -#[cfg_attr(feature = "serde", derive(serde::Serialize))] pub struct PushPullMessageRef<'a, I> { /// Current node lamport time + #[viewit(getter(const, style = "move", attrs(doc = "Returns the lamport time")))] ltime: LamportTime, /// Maps the node to its status time - status_ltimes: &'a IndexMap, + #[viewit(getter( + const, + style = "ref", + attrs(doc = "Returns the maps the node to its status time") + ))] + status_ltimes: RepeatedDecoder<'a>, /// List of left nodes - left_members: &'a IndexSet, + #[viewit(getter(const, style = "ref", attrs(doc = "Returns the list of left nodes")))] + left_members: RepeatedDecoder<'a>, /// Lamport time for event clock + #[viewit(getter( + const, + style = "move", + attrs(doc = "Returns the lamport time for event clock") + ))] event_ltime: LamportTime, /// Recent events - events: &'a [Option], + #[viewit(getter(const, style = "ref", attrs(doc = "Returns the recent events")))] + events: RepeatedDecoder<'a>, /// Lamport time for query clock + #[viewit(getter( + const, + style = "move", + attrs(doc = "Returns the lamport time for query clock") + ))] query_ltime: LamportTime, + #[viewit(getter(skip))] + _m: core::marker::PhantomData, } impl Clone for PushPullMessageRef<'_, I> { @@ -122,30 +158,300 @@ impl Clone for PushPullMessageRef<'_, I> { impl Copy for PushPullMessageRef<'_, I> {} -impl<'a, I> From<&'a PushPullMessage> for PushPullMessageRef<'a, I> { - #[inline] - fn from(msg: &'a PushPullMessage) -> Self { - Self { - ltime: msg.ltime, - status_ltimes: &msg.status_ltimes, - left_members: &msg.left_members, - event_ltime: 
msg.event_ltime, - events: &msg.events, - query_ltime: msg.query_ltime, +impl<'a, I> DataRef<'a, PushPullMessage> for PushPullMessageRef<'a, I::Ref<'a>> +where + I: Data + Eq + core::hash::Hash, +{ + fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> + where + Self: Sized, + { + let mut offset = 0; + let buf_len = buf.len(); + + let mut ltime = None; + let mut status_ltimes_offsets = None; + let mut num_status_ltimes = 0; + let mut left_members_offsets = None; + let mut num_left_members = 0; + let mut event_ltime = None; + let mut events_offsets = None; + let mut num_events = 0; + let mut query_ltime = None; + + while offset < buf_len { + match buf[offset] { + LTIME_BYTE => { + if ltime.is_some() { + return Err(DecodeError::duplicate_field( + "PushPullMessage", + "ltime", + LTIME_TAG, + )); + } + + offset += 1; + let (o, v) = >::decode(&buf[offset..])?; + offset += o; + ltime = Some(v); + } + STATUS_LTIMES_BYTE => { + let readed = skip(WireType::LengthDelimited, &buf[offset..])?; + if let Some((ref mut fnso, ref mut lnso)) = status_ltimes_offsets { + if *fnso > offset { + *fnso = offset - 1; + } + + if *lnso < offset + readed { + *lnso = offset + readed; + } + } else { + status_ltimes_offsets = Some((offset - 1, offset + readed)); + } + num_status_ltimes += 1; + offset += readed; + } + LEFT_MEMBERS_BYTE => { + let readed = skip(WireType::LengthDelimited, &buf[offset..])?; + if let Some((ref mut fnso, ref mut lnso)) = left_members_offsets { + if *fnso > offset { + *fnso = offset - 1; + } + + if *lnso < offset + readed { + *lnso = offset + readed; + } + } else { + left_members_offsets = Some((offset - 1, offset + readed)); + } + num_left_members += 1; + offset += readed; + } + EVENT_LTIME_BYTE => { + if event_ltime.is_some() { + return Err(DecodeError::duplicate_field( + "PushPullMessage", + "event_ltime", + EVENT_LTIME_TAG, + )); + } + + offset += 1; + let (o, v) = >::decode(&buf[offset..])?; + offset += o; + event_ltime = Some(v); + } + EVENTS_BYTE => { 
+ let readed = skip(WireType::LengthDelimited, &buf[offset..])?; + if let Some((ref mut fnso, ref mut lnso)) = events_offsets { + if *fnso > offset { + *fnso = offset - 1; + } + + if *lnso < offset + readed { + *lnso = offset + readed; + } + } else { + events_offsets = Some((offset - 1, offset + readed)); + } + num_events += 1; + offset += readed; + } + QUERY_LTIME_BYTE => { + if query_ltime.is_some() { + return Err(DecodeError::duplicate_field( + "PushPullMessage", + "query_ltime", + QUERY_LTIME_TAG, + )); + } + + offset += 1; + let (o, v) = >::decode(&buf[offset..])?; + offset += o; + query_ltime = Some(v); + } + other => { + offset += 1; + + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &buf[offset..])?; + } + } } + + Ok(( + offset, + Self { + ltime: ltime.ok_or_else(|| DecodeError::missing_field("PushPullMessage", "ltime"))?, + status_ltimes: if let Some((start, end)) = events_offsets { + RepeatedDecoder::new(STATUS_LTIMES_TAG, WireType::LengthDelimited, buf) + .with_nums(num_status_ltimes) + .with_offsets(start, end) + } else { + RepeatedDecoder::new(STATUS_LTIMES_TAG, WireType::LengthDelimited, buf) + }, + left_members: if let Some((start, end)) = events_offsets { + RepeatedDecoder::new(LEFT_MEMBERS_TAG, WireType::LengthDelimited, buf) + .with_nums(num_left_members) + .with_offsets(start, end) + } else { + RepeatedDecoder::new(LEFT_MEMBERS_TAG, WireType::LengthDelimited, buf) + }, + event_ltime: event_ltime + .ok_or_else(|| DecodeError::missing_field("PushPullMessage", "event_ltime"))?, + events: if let Some((start, end)) = events_offsets { + RepeatedDecoder::new(EVENTS_TAG, WireType::LengthDelimited, buf) + .with_nums(num_events) + .with_offsets(start, end) + } else { + RepeatedDecoder::new(EVENTS_TAG, WireType::LengthDelimited, buf) + }, + query_ltime: query_ltime + .ok_or_else(|| DecodeError::missing_field("PushPullMessage", "query_ltime"))?, + _m: 
core::marker::PhantomData, + }, + )) } } -impl<'a, I> From<&'a mut PushPullMessage> for PushPullMessageRef<'a, I> { - #[inline] - fn from(msg: &'a mut PushPullMessage) -> Self { - Self { - ltime: msg.ltime, - status_ltimes: &msg.status_ltimes, - left_members: &msg.left_members, - event_ltime: msg.event_ltime, - events: &msg.events, - query_ltime: msg.query_ltime, +impl Data for PushPullMessage +where + I: Data + Eq + core::hash::Hash, +{ + type Ref<'a> = PushPullMessageRef<'a, I::Ref<'a>>; + + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + let left_members = val + .left_members + .iter::() + .map(|res| res.and_then(Data::from_ref)) + .collect::, DecodeError>>()?; + + Ok(Self { + ltime: val.ltime, + status_ltimes: val + .status_ltimes + .iter::<(I, LamportTime)>() + .map(|res| res.and_then(Data::from_ref)) + .collect::, DecodeError>>()?, + left_members, + event_ltime: val.event_ltime, + events: val + .events + .iter::() + .map(|res| res.and_then(Data::from_ref)) + .collect::, DecodeError>>()?, + query_ltime: val.query_ltime, + }) + } + + fn encoded_len(&self) -> usize { + let mut len = 0usize; + + len += 1 + self.ltime.encoded_len(); + + len += self + .status_ltimes + .iter() + .map(|(k, v)| 1 + TupleEncoder::new(k, v).encoded_len_with_length_delimited()) + .sum::(); + + len += self + .left_members + .iter() + .map(|id| 1 + id.encoded_len_with_length_delimited()) + .sum::(); + len += 1 + self.event_ltime.encoded_len(); + len += 1 + + self + .events + .iter() + .map(|e| 1 + e.encoded_len_with_length_delimited()) + .sum::(); + len += 1 + self.query_ltime.encoded_len(); + + len + } + + fn encode(&self, buf: &mut [u8]) -> Result { + macro_rules! 
bail { + ($this:ident($offset:expr, $len:ident)) => { + if $offset >= $len { + return Err(EncodeError::insufficient_buffer(self.encoded_len(), $len)); + } + }; } + + let mut offset = 0; + let buf_len = buf.len(); + + bail!(self(offset, buf_len)); + buf[offset] = LTIME_BYTE; + offset += 1; + offset += self.ltime.encode(&mut buf[offset..])?; + + bail!(self(offset, buf_len)); + buf[offset] = STATUS_LTIMES_BYTE; + offset += 1; + + self + .status_ltimes + .iter() + .try_fold(&mut offset, |off, (k, v)| { + bail!(self(*off, buf_len)); + buf[*off] = LEFT_MEMBERS_BYTE; + *off += 1; + *off += TupleEncoder::new(k, v).encode_with_length_delimited(&mut buf[*off..])?; + Ok(off) + }) + .map_err(|e: EncodeError| e.update(self.encoded_len(), buf_len))?; + + self + .left_members + .iter() + .try_fold(&mut offset, |off, id| { + bail!(self(*off, buf_len)); + buf[*off] = LEFT_MEMBERS_BYTE; + *off += 1; + *off += id.encode_length_delimited(&mut buf[*off..])?; + Ok(off) + }) + .map_err(|e: EncodeError| e.update(self.encoded_len(), buf_len))?; + + bail!(self(offset, buf_len)); + buf[offset] = EVENT_LTIME_BYTE; + offset += 1; + offset += self.event_ltime.encode(&mut buf[offset..])?; + + bail!(self(offset, buf_len)); + buf[offset] = EVENTS_BYTE; + offset += 1; + + self + .events + .iter() + .try_fold(&mut offset, |off, e| { + bail!(self(*off, buf_len)); + buf[*off] = EVENTS_BYTE; + *off += 1; + *off += e.encode_length_delimited(&mut buf[*off..])?; + Ok(off) + }) + .map_err(|e: EncodeError| e.update(self.encoded_len(), buf_len))?; + + bail!(self(offset, buf_len)); + buf[offset] = QUERY_LTIME_BYTE; + offset += 1; + offset += self.query_ltime.encode(&mut buf[offset..])?; + + #[cfg(debug_assertions)] + super::debug_assert_write_eq(offset, self.encoded_len()); + + Ok(offset) } } diff --git a/serf-proto/src/tags.rs b/serf-proto/src/tags.rs index 5835e17..607c492 100644 --- a/serf-proto/src/tags.rs +++ b/serf-proto/src/tags.rs @@ -1,6 +1,6 @@ use indexmap::IndexMap; use memberlist_proto::{ - Data, 
DataRef, DecodeError, EncodeError, RepeatedDecoder, WireType, + Data, DataRef, DecodeError, EncodeError, RepeatedDecoder, TupleEncoder, WireType, utils::{merge, skip, split}, }; use smol_str::SmolStr; @@ -136,8 +136,8 @@ impl Data for Tags { { val .src - .iter::() - .map(|res| res.and_then(|t| Tag::from_ref(t).map(Tag::split))) + .iter::<(SmolStr, SmolStr)>() + .map(|res| res.and_then(Data::from_ref)) .collect::, DecodeError>>() .map(Self) } @@ -146,7 +146,7 @@ impl Data for Tags { self .0 .iter() - .map(|(k, v)| 1 + TagRef::new(k, v).encoded_len_with_length_delimited()) + .map(|(k, v)| 1 + TupleEncoder::new(k, v).encoded_len_with_length_delimited()) .sum::() } @@ -162,219 +162,219 @@ impl Data for Tags { buf[offset] = TAGS_BYTE; offset += 1; - offset += TagRef::new(k, v).encode_with_length_delimited(&mut buf[offset..])?; + offset += TupleEncoder::new(k, v).encode_with_length_delimited(&mut buf[offset..])?; Ok(offset) }) .map_err(|e: EncodeError| e.update(self.encoded_len(), buf.len())) } } -#[derive(Debug)] -struct Tag { - key: SmolStr, - value: SmolStr, -} - -impl Tag { - fn split(self) -> (SmolStr, SmolStr) { - (self.key, self.value) - } -} - -impl Data for Tag { - type Ref<'a> = TagRef<'a>; - - fn from_ref(val: Self::Ref<'_>) -> Result - where - Self: Sized, - { - Ok(Self { - key: SmolStr::new(val.key), - value: SmolStr::new(val.value), - }) - } - - fn encoded_len(&self) -> usize { - TagRef::new(&self.key, &self.value).encoded_len() - } - - fn encode(&self, buf: &mut [u8]) -> Result { - TagRef::new(&self.key, &self.value).encode(buf) - } -} - -/// A reference to a (key, value) pair of a tag -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct TagRef<'a> { - key: &'a str, - value: &'a str, -} - -impl<'a> DataRef<'a, Tag> for TagRef<'a> { - fn decode(src: &'a [u8]) -> Result<(usize, Self), DecodeError> - where - Self: Sized, - { - let mut offset = 0; - let buf_len = src.len(); - - let mut key = None; - let mut val = None; - - while offset < buf_len { - 
match src[offset] { - Self::KEY_BYTE => { - if key.is_some() { - return Err(DecodeError::duplicate_field("Tag", "key", Self::KEY_TAG)); - } - offset += 1; - - let (read, value) = - <&str as DataRef<'_, SmolStr>>::decode_length_delimited(&src[offset..])?; - key = Some(value); - offset += read; - } - Self::VALUE_BYTE => { - if val.is_some() { - return Err(DecodeError::duplicate_field( - "Tag", - "value", - Self::VALUE_TAG, - )); - } - offset += 1; - - let (read, value) = - <&str as DataRef<'_, SmolStr>>::decode_length_delimited(&src[offset..])?; - val = Some(value); - offset += read; - } - other => { - offset += 1; - - let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; - offset += skip(wire_type, &src[offset..])?; - } - } - } - - Ok(( - offset, - Self { - key: key.unwrap_or(""), - value: val.unwrap_or(""), - }, - )) - } -} - -impl<'a> TagRef<'a> { - const KEY_TAG: u8 = 1; - const KEY_BYTE: u8 = merge(WireType::LengthDelimited, Self::KEY_TAG); - const VALUE_TAG: u8 = 2; - const VALUE_BYTE: u8 = merge(WireType::LengthDelimited, Self::VALUE_TAG); - - fn new(key: &'a str, value: &'a str) -> Self { - Self { key, value } - } - - fn encoded_len(&self) -> usize { - let klen = self.key.len(); - let vlen = self.value.len(); - - let mut len = 0; - if klen != 0 { - len += 1 + (klen as u32).encoded_len(); - } - - if vlen != 0 { - len += 1 + (vlen as u32).encoded_len(); - } - - len - } - - fn encoded_len_with_length_delimited(&self) -> usize { - let len = self.encoded_len(); - len + (len as u32).encoded_len() - } - - fn encode(&self, buf: &mut [u8]) -> Result { - let buf_len = buf.len(); - let mut offset = 0; - - if buf_len <= offset { - return Err(EncodeError::insufficient_buffer( - self.encoded_len(), - buf_len, - )); - } - - let klen = self.key.len(); - if klen != 0 { - buf[offset] = Self::KEY_BYTE; - offset += 1; - - offset += (klen as u32) - .encode(&mut buf[offset..]) - .map_err(|e| 
e.update(self.encoded_len(), buf_len))?; - if buf_len < offset + klen { - return Err(EncodeError::insufficient_buffer( - self.encoded_len(), - buf_len, - )); - } - buf[offset..offset + klen].copy_from_slice(self.key.as_bytes()); - offset += klen; - } - - if buf_len <= offset { - return Err(EncodeError::insufficient_buffer( - self.encoded_len(), - buf_len, - )); - } - - let vlen = self.value.len(); - if vlen != 0 { - buf[offset] = Self::VALUE_BYTE; - offset += 1; - - offset += (vlen as u32) - .encode(&mut buf[offset..]) - .map_err(|e| e.update(self.encoded_len(), buf_len))?; - if buf_len < offset + vlen { - return Err(EncodeError::insufficient_buffer( - self.encoded_len(), - buf_len, - )); - } - - buf[offset..offset + vlen].copy_from_slice(self.value.as_bytes()); - offset += vlen; - } - - #[cfg(debug_assertions)] - super::debug_assert_write_eq(offset, self.encoded_len()); - - Ok(offset) - } - - fn encode_with_length_delimited(&self, buf: &mut [u8]) -> Result { - let len = self.encoded_len(); - let buf_len = buf.len(); - if buf_len < len { - return Err(EncodeError::insufficient_buffer(len, buf_len)); - } - - let mut offset = 0; - offset += (len as u32).encode(&mut buf[offset..])?; - offset += self.encode(&mut buf[offset..])?; - - #[cfg(debug_assertions)] - super::debug_assert_write_eq(offset, self.encoded_len_with_length_delimited()); - - Ok(offset) - } -} +// #[derive(Debug)] +// struct Tag { +// key: SmolStr, +// value: SmolStr, +// } + +// impl Tag { +// fn split(self) -> (SmolStr, SmolStr) { +// (self.key, self.value) +// } +// } + +// impl Data for Tag { +// type Ref<'a> = TagRef<'a>; + +// fn from_ref(val: Self::Ref<'_>) -> Result +// where +// Self: Sized, +// { +// Ok(Self { +// key: SmolStr::new(val.key), +// value: SmolStr::new(val.value), +// }) +// } + +// fn encoded_len(&self) -> usize { +// TagRef::new(&self.key, &self.value).encoded_len() +// } + +// fn encode(&self, buf: &mut [u8]) -> Result { +// TagRef::new(&self.key, &self.value).encode(buf) +// } 
+// } + +// /// A reference to a (key, value) pair of a tag +// #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +// pub struct TagRef<'a> { +// key: &'a str, +// value: &'a str, +// } + +// impl<'a> DataRef<'a, Tag> for TagRef<'a> { +// fn decode(src: &'a [u8]) -> Result<(usize, Self), DecodeError> +// where +// Self: Sized, +// { +// let mut offset = 0; +// let buf_len = src.len(); + +// let mut key = None; +// let mut val = None; + +// while offset < buf_len { +// match src[offset] { +// Self::KEY_BYTE => { +// if key.is_some() { +// return Err(DecodeError::duplicate_field("Tag", "key", Self::KEY_TAG)); +// } +// offset += 1; + +// let (read, value) = +// <&str as DataRef<'_, SmolStr>>::decode_length_delimited(&src[offset..])?; +// key = Some(value); +// offset += read; +// } +// Self::VALUE_BYTE => { +// if val.is_some() { +// return Err(DecodeError::duplicate_field( +// "Tag", +// "value", +// Self::VALUE_TAG, +// )); +// } +// offset += 1; + +// let (read, value) = +// <&str as DataRef<'_, SmolStr>>::decode_length_delimited(&src[offset..])?; +// val = Some(value); +// offset += read; +// } +// other => { +// offset += 1; + +// let (wire_type, _) = split(other); +// let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; +// offset += skip(wire_type, &src[offset..])?; +// } +// } +// } + +// Ok(( +// offset, +// Self { +// key: key.unwrap_or(""), +// value: val.unwrap_or(""), +// }, +// )) +// } +// } + +// impl<'a> TagRef<'a> { +// const KEY_TAG: u8 = 1; +// const KEY_BYTE: u8 = merge(WireType::LengthDelimited, Self::KEY_TAG); +// const VALUE_TAG: u8 = 2; +// const VALUE_BYTE: u8 = merge(WireType::LengthDelimited, Self::VALUE_TAG); + +// fn new(key: &'a str, value: &'a str) -> Self { +// Self { key, value } +// } + +// fn encoded_len(&self) -> usize { +// let klen = self.key.len(); +// let vlen = self.value.len(); + +// let mut len = 0; +// if klen != 0 { +// len += 1 + (klen as u32).encoded_len(); +// } + +// if vlen != 0 { 
+// len += 1 + (vlen as u32).encoded_len(); +// } + +// len +// } + +// fn encoded_len_with_length_delimited(&self) -> usize { +// let len = self.encoded_len(); +// len + (len as u32).encoded_len() +// } + +// fn encode(&self, buf: &mut [u8]) -> Result { +// let buf_len = buf.len(); +// let mut offset = 0; + +// if buf_len <= offset { +// return Err(EncodeError::insufficient_buffer( +// self.encoded_len(), +// buf_len, +// )); +// } + +// let klen = self.key.len(); +// if klen != 0 { +// buf[offset] = Self::KEY_BYTE; +// offset += 1; + +// offset += (klen as u32) +// .encode(&mut buf[offset..]) +// .map_err(|e| e.update(self.encoded_len(), buf_len))?; +// if buf_len < offset + klen { +// return Err(EncodeError::insufficient_buffer( +// self.encoded_len(), +// buf_len, +// )); +// } +// buf[offset..offset + klen].copy_from_slice(self.key.as_bytes()); +// offset += klen; +// } + +// if buf_len <= offset { +// return Err(EncodeError::insufficient_buffer( +// self.encoded_len(), +// buf_len, +// )); +// } + +// let vlen = self.value.len(); +// if vlen != 0 { +// buf[offset] = Self::VALUE_BYTE; +// offset += 1; + +// offset += (vlen as u32) +// .encode(&mut buf[offset..]) +// .map_err(|e| e.update(self.encoded_len(), buf_len))?; +// if buf_len < offset + vlen { +// return Err(EncodeError::insufficient_buffer( +// self.encoded_len(), +// buf_len, +// )); +// } + +// buf[offset..offset + vlen].copy_from_slice(self.value.as_bytes()); +// offset += vlen; +// } + +// #[cfg(debug_assertions)] +// super::debug_assert_write_eq(offset, self.encoded_len()); + +// Ok(offset) +// } + +// fn encode_with_length_delimited(&self, buf: &mut [u8]) -> Result { +// let len = self.encoded_len(); +// let buf_len = buf.len(); +// if buf_len < len { +// return Err(EncodeError::insufficient_buffer(len, buf_len)); +// } + +// let mut offset = 0; +// offset += (len as u32).encode(&mut buf[offset..])?; +// offset += self.encode(&mut buf[offset..])?; + +// #[cfg(debug_assertions)] +// 
super::debug_assert_write_eq(offset, self.encoded_len_with_length_delimited()); + +// Ok(offset) +// } +// } From fc738dc0c7d49cf43d71621a8ffa5d29fb00846b Mon Sep 17 00:00:00 2001 From: al8n Date: Wed, 26 Feb 2025 23:47:45 +0800 Subject: [PATCH 05/39] WIP --- README.md | 4 +- serf-core/src/coalesce/member.rs | 4 +- serf-core/src/coalesce/user.rs | 6 +- serf-core/src/coordinate.rs | 146 +--------- serf-core/src/delegate/composite.rs | 2 +- serf-core/src/delegate/merge.rs | 2 +- serf-core/src/error.rs | 167 +----------- serf-core/src/event.rs | 8 +- serf-core/src/event/crate_event.rs | 4 +- serf-core/src/key_manager.rs | 8 +- serf-core/src/lib.rs | 254 +++++++++--------- serf-core/src/options.rs | 2 +- serf-core/src/serf.rs | 2 +- serf-core/src/serf/api.rs | 14 +- serf-core/src/serf/base.rs | 26 +- serf-core/src/serf/base/tests.rs | 14 +- serf-core/src/serf/base/tests/serf.rs | 6 +- .../src/serf/base/tests/serf/delegate.rs | 8 +- serf-core/src/serf/base/tests/serf/event.rs | 6 +- serf-core/src/serf/base/tests/serf/join.rs | 4 +- serf-core/src/serf/delegate.rs | 36 +-- serf-core/src/serf/internal_query.rs | 24 +- serf-core/src/serf/query.rs | 31 +-- serf-core/src/snapshot.rs | 13 +- serf-core/src/types/member.rs | 4 +- 25 files changed, 236 insertions(+), 559 deletions(-) diff --git a/README.md b/README.md index 074e717..57552f7 100644 --- a/README.md +++ b/README.md @@ -87,7 +87,7 @@ Here are the layers: Used to involve a client in a potential cluster merge operation. Namely, when a node does a promised push/pull (as part of a join), the delegate is involved and allowed to cancel the join based on custom logic. The merge delegate is NOT invoked as part of the push-pull anti-entropy. - - **`TransformDelegate`** + - **``** A delegate for encoding and decoding. Used to control how `serf` should encode/decode messages. @@ -118,7 +118,7 @@ serf = "0.2" - ***Does Rust's serf implemenetation compatible to Go's serf?*** - No but yes! By default, it is not compatible. 
But the secret is the serialize/deserilize layer, Go's serf use the msgpack as the serialization/deserialization framework, so in theory, if you can implement a [`TransformDelegate`](https://docs.rs/serf-core/transport/trait.TransformDelegate.html) trait which compat to Go's serf, then it becomes compatible. + No but yes! By default, it is not compatible. But the secret is the serialize/deserilize layer, Go's serf use the msgpack as the serialization/deserialization framework, so in theory, if you can implement a [``](https://docs.rs/serf-core/transport/trait..html) trait which compat to Go's serf, then it becomes compatible. - ***If Go's serf adds more functionalities, will this project also support?*** diff --git a/serf-core/src/coalesce/member.rs b/serf-core/src/coalesce/member.rs index 54b356d..c439c8a 100644 --- a/serf-core/src/coalesce/member.rs +++ b/serf-core/src/coalesce/member.rs @@ -4,7 +4,7 @@ use async_channel::Sender; use memberlist_core::{ CheapClone, transport::{AddressResolver, Node, Transport}, - types::TinyVec, + proto::TinyVec, }; use crate::{ @@ -126,7 +126,7 @@ mod tests { use futures::FutureExt; use memberlist_core::{ agnostic_lite::{RuntimeLite, tokio::TokioRuntime}, - transport::{Lpe, resolver::socket_addr::SocketAddrResolver, tests::UnimplementedTransport}, + transport::resolver::socket_addr::SocketAddrResolver, }; use serf_proto::{MemberStatus, UserEventMessage}; use smol_str::SmolStr; diff --git a/serf-core/src/coalesce/user.rs b/serf-core/src/coalesce/user.rs index de82e50..0957fb0 100644 --- a/serf-core/src/coalesce/user.rs +++ b/serf-core/src/coalesce/user.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use indexmap::IndexMap; -use memberlist_core::types::TinyVec; +use memberlist_core::proto::TinyVec; use serf_proto::UserEventMessage; use smol_str::SmolStr; @@ -102,9 +102,7 @@ mod tests { use std::net::SocketAddr; use agnostic_lite::tokio::TokioRuntime; - use memberlist_core::transport::{ - Lpe, 
resolver::socket_addr::SocketAddrResolver, tests::UnimplementedTransport, - }; + use memberlist_core::transport::resolver::socket_addr::SocketAddrResolver; use crate::{ DefaultDelegate, diff --git a/serf-core/src/coordinate.rs b/serf-core/src/coordinate.rs index 096fc9b..6a8a755 100644 --- a/serf-core/src/coordinate.rs +++ b/serf-core/src/coordinate.rs @@ -4,11 +4,9 @@ use std::{ time::Duration, }; -use byteorder::{ByteOrder, NetworkEndian}; use memberlist_core::CheapClone; use parking_lot::RwLock; use rand::Rng; -use serf_proto::Transformable; use smallvec::SmallVec; /// Used to convert float seconds to nanoseconds. @@ -184,7 +182,7 @@ pub struct CoordinateOptions { doc = "Sets the metric labels used to identify the metrics for this coordinate client." )) )] - metric_labels: std::sync::Arc, + metric_labels: std::sync::Arc, } impl Default for CoordinateOptions { @@ -208,7 +206,7 @@ impl CoordinateOptions { latency_filter_size: 3, gravity_rho: 150.0, #[cfg(feature = "metrics")] - metric_labels: std::sync::Arc::new(memberlist_core::types::MetricLabels::default()), + metric_labels: std::sync::Arc::new(memberlist_core::proto::MetricLabels::default()), } } } @@ -649,101 +647,6 @@ impl Coordinate { } } -/// The error when encoding or decoding a coordinate. -#[derive(Debug, thiserror::Error)] -pub enum CoordinateTransformError { - /// Returned when the buffer is too small to encode the coordinate. - #[error("encode buffer too small")] - BufferTooSmall, - /// Returned when there are not enough bytes to decode the coordinate. 
- #[error("not enough bytes to decode")] - NotEnoughBytes, -} - -impl Transformable for Coordinate { - type Error = CoordinateTransformError; - - fn encode(&self, dst: &mut [u8]) -> Result { - let encoded_len = self.encoded_len(); - if dst.len() < encoded_len { - return Err(Self::Error::BufferTooSmall); - } - - let mut offset = 0; - NetworkEndian::write_u32(&mut dst[offset..], encoded_len as u32); - offset += 4; - NetworkEndian::write_f64(&mut dst[offset..], self.error); - offset += 8; - NetworkEndian::write_f64(&mut dst[offset..], self.adjustment); - offset += 8; - NetworkEndian::write_f64(&mut dst[offset..], self.height); - offset += 8; - for f in &self.portion { - NetworkEndian::write_f64(&mut dst[offset..], *f); - offset += 8; - } - - debug_assert_eq!( - offset, encoded_len, - "expect write {} bytes, but actual write {} bytes", - encoded_len, offset - ); - - Ok(offset) - } - - fn encoded_len(&self) -> usize { - 4 + 8 * self.portion.len() + 8 * 3 - } - - fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error> - where - Self: Sized, - { - let src_len = src.len(); - if src_len < 4 + 3 * 8 { - return Err(Self::Error::NotEnoughBytes); - } - - let len = NetworkEndian::read_u32(&src[0..4]) as usize; - if src_len < len { - return Err(Self::Error::NotEnoughBytes); - } - - let mut offset = 4; - let error = NetworkEndian::read_f64(&src[offset..]); - offset += 8; - let adjustment = NetworkEndian::read_f64(&src[offset..]); - offset += 8; - let height = NetworkEndian::read_f64(&src[offset..]); - offset += 8; - - let num_portion = (len - 4 - 3 * 8) / 8; - let mut portion = SmallVec::with_capacity(num_portion); - - for _ in 0..num_portion { - portion.push(NetworkEndian::read_f64(&src[offset..])); - offset += 8; - } - - debug_assert_eq!( - offset, len, - "expect read {} bytes, but actual read {} bytes", - len, offset - ); - - Ok(( - len, - Self { - portion, - error, - adjustment, - height, - }, - )) - } -} - #[inline] fn add_in_place(vec1: &mut [f64], vec2: &[f64]) { for 
(x, y) in vec1.iter_mut().zip(vec2.iter()) { @@ -810,9 +713,9 @@ fn unit_vector_at(vec1: &[f64], vec2: &[f64]) -> (SmallVec<[f64; DEFAULT_DIMENSI } fn rand_f64() -> f64 { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); loop { - let f = (rng.gen_range::(0..(1u64 << 63u64)) as f64) / ((1u64 << 63u64) as f64); + let f = (rng.random_range::(0..(1u64 << 63u64)) as f64) / ((1u64 << 63u64) as f64); if f == 1.0 { continue; } @@ -842,47 +745,6 @@ mod tests { } } - impl Coordinate { - fn random(size: usize) -> Self { - let mut portion = SmallVec::with_capacity(size); - for _ in 0..size { - portion.push(rand_f64()); - } - Self { - portion, - error: rand_f64(), - adjustment: rand_f64(), - height: rand_f64(), - } - } - } - - #[tokio::test] - async fn test_transform_encode_decode() { - for i in 0..100 { - let filter = Coordinate::random(i); - let mut buf = vec![0; filter.encoded_len()]; - let encoded_len = filter.encode(&mut buf).unwrap(); - assert_eq!(encoded_len, filter.encoded_len()); - - let (decoded_len, decoded) = Coordinate::decode(&buf).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - - let (decoded_len, decoded) = - Coordinate::decode_from_reader(&mut std::io::Cursor::new(&buf)).unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - - let (decoded_len, decoded) = - Coordinate::decode_from_async_reader(&mut futures::io::Cursor::new(&buf)) - .await - .unwrap(); - assert_eq!(decoded_len, encoded_len); - assert_eq!(decoded, filter); - } - } - #[test] fn test_client_update() { let cfg = CoordinateOptions::default().with_dimensionality(3); diff --git a/serf-core/src/delegate/composite.rs b/serf-core/src/delegate/composite.rs index 39926fe..78cdee6 100644 --- a/serf-core/src/delegate/composite.rs +++ b/serf-core/src/delegate/composite.rs @@ -12,7 +12,7 @@ use crate::{ use super::{ DefaultMergeDelegate, Delegate, LpeTransfromDelegate, MergeDelegate, NoopReconnectDelegate, - ReconnectDelegate, 
TransformDelegate, + ReconnectDelegate, , }; /// `CompositeDelegate` is a helpful struct to split the [`Delegate`] into multiple small delegates, diff --git a/serf-core/src/delegate/merge.rs b/serf-core/src/delegate/merge.rs index 004edde..b2b2938 100644 --- a/serf-core/src/delegate/merge.rs +++ b/serf-core/src/delegate/merge.rs @@ -1,4 +1,4 @@ -use memberlist_core::{CheapClone, transport::Id, types::TinyVec}; +use memberlist_core::{CheapClone, transport::Id, proto::TinyVec}; use std::future::Future; use crate::types::Member; diff --git a/serf-core/src/error.rs b/serf-core/src/error.rs index bfed1c0..8e5d84d 100644 --- a/serf-core/src/error.rs +++ b/serf-core/src/error.rs @@ -1,77 +1,18 @@ -use std::{borrow::Cow, collections::HashMap}; +use std::collections::HashMap; use memberlist_core::{ - delegate::DelegateError as MemberlistDelegateError, transport::{AddressResolver, MaybeResolvedAddress, Node, Transport}, - types::{SmallVec, TinyVec}, + proto::{SmallVec, TinyVec}, }; -use smol_str::SmolStr; use crate::{ - delegate::{Delegate, MergeDelegate, TransformDelegate}, + delegate::Delegate, serf::{SerfDelegate, SerfState}, types::Member, }; pub use crate::snapshot::SnapshotError; -/// Error trait for [`Delegate`] -#[derive(thiserror::Error)] -pub enum SerfDelegateError { - /// Serf error - #[error(transparent)] - Serf(#[from] SerfError), - /// [`TransformDelegate`] error - #[error(transparent)] - TransformDelegate(::Error), - /// [`MergeDelegate`] error - #[error(transparent)] - MergeDelegate(::Error), -} - -impl core::fmt::Debug for SerfDelegateError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::TransformDelegate(err) => write!(f, "{err:?}"), - Self::MergeDelegate(err) => write!(f, "{err:?}"), - Self::Serf(err) => write!(f, "{err:?}"), - } - } -} - -impl SerfDelegateError { - /// Create a delegate error from an alive delegate error. 
- #[inline] - pub const fn transform(err: ::Error) -> Self { - Self::TransformDelegate(err) - } - - /// Create a delegate error from a merge delegate error. - #[inline] - pub const fn merge(err: ::Error) -> Self { - Self::MergeDelegate(err) - } - - /// Create a delegate error from a serf error. - #[inline] - pub const fn serf(err: crate::error::SerfError) -> Self { - Self::Serf(err) - } -} - -impl From>> for SerfDelegateError -where - D: Delegate::ResolvedAddress>, - T: Transport, -{ - fn from(value: MemberlistDelegateError>) -> Self { - match value { - MemberlistDelegateError::AliveDelegate(e) => e, - MemberlistDelegateError::MergeDelegate(e) => e, - } - } -} - /// Error type for the serf crate. #[derive(thiserror::Error)] pub enum Error @@ -81,53 +22,15 @@ where { /// Returned when the underlyhing memberlist error #[error(transparent)] - Memberlist(#[from] MemberlistError::ResolvedAddress>), + Memberlist(#[from] memberlist_core::error::Error>), /// Returned when the serf error #[error(transparent)] Serf(#[from] SerfError), - /// Returned when the transport error - #[error(transparent)] - Transport(T::Error), - /// Returned when the delegate error - #[error(transparent)] - Delegate(#[from] SerfDelegateError), /// Returned when the relay error #[error(transparent)] Relay(#[from] RelayError), } -impl From>> for Error -where - D: Delegate::ResolvedAddress>, - T: Transport, -{ - fn from(value: memberlist_core::error::Error>) -> Self { - match value { - memberlist_core::error::Error::NotRunning => Self::Memberlist(MemberlistError::NotRunning), - memberlist_core::error::Error::UpdateTimeout => { - Self::Memberlist(MemberlistError::UpdateTimeout) - } - memberlist_core::error::Error::LeaveTimeout => { - Self::Memberlist(MemberlistError::LeaveTimeout) - } - memberlist_core::error::Error::Lost(n) => Self::Memberlist(MemberlistError::Lost(n)), - memberlist_core::error::Error::Delegate(e) => match e.into() { - SerfDelegateError::Serf(e) => Self::Serf(e), - e => 
Self::Delegate(e), - }, - memberlist_core::error::Error::Transport(e) => Self::Transport(e), - memberlist_core::error::Error::UnexpectedMessage { expected, got } => { - Self::Memberlist(MemberlistError::UnexpectedMessage { expected, got }) - } - memberlist_core::error::Error::SequenceNumberMismatch { ping, ack } => { - Self::Memberlist(MemberlistError::SequenceNumberMismatch { ping, ack }) - } - memberlist_core::error::Error::Remote(e) => Self::Memberlist(MemberlistError::Remote(e)), - memberlist_core::error::Error::Other(e) => Self::Memberlist(MemberlistError::Other(e)), - } - } -} - impl core::fmt::Debug for Error where D: Delegate::ResolvedAddress>, @@ -137,8 +40,6 @@ where match self { Self::Memberlist(e) => write!(f, "{e:?}"), Self::Serf(e) => write!(f, "{e:?}"), - Self::Transport(e) => write!(f, "{e:?}"), - Self::Delegate(e) => write!(f, "{e:?}"), Self::Relay(e) => write!(f, "{e:?}"), } } @@ -159,18 +60,6 @@ where D: Delegate::ResolvedAddress>, T: Transport, { - /// Create error from a transform error - #[inline] - pub fn transform_delegate(err: ::Error) -> Self { - Self::Delegate(SerfDelegateError::TransformDelegate(err)) - } - - /// Create a merge delegate error - #[inline] - pub const fn merge_delegate(err: ::Error) -> Self { - Self::Delegate(SerfDelegateError::MergeDelegate(err)) - } - /// Create a query response too large error #[inline] pub const fn query_response_too_large(limit: usize, got: usize) -> Self { @@ -261,14 +150,6 @@ where Self::Serf(SerfError::Snapshot(err)) } - /// Create a memberlist error - #[inline] - pub const fn memberlist( - err: MemberlistError::ResolvedAddress>, - ) -> Self { - Self::Memberlist(err) - } - /// Create a bad leave status error #[inline] pub const fn bad_leave_status(status: SerfState) -> Self { @@ -349,46 +230,6 @@ pub enum SerfError { BroadcastChannelClosed, } -/// Error type for [`Memberlist`](memberlist_core::Memberlist). 
-#[derive(Debug, thiserror::Error)] -pub enum MemberlistError { - /// Returns when the node is not running. - #[error("memberlist: node is not running, please bootstrap first")] - NotRunning, - /// Returns when timeout waiting for update broadcast. - #[error("memberlist: timeout waiting for update broadcast")] - UpdateTimeout, - /// Returns when timeout waiting for leave broadcast. - #[error("memberlist: timeout waiting for leave broadcast")] - LeaveTimeout, - /// Returns when lost connection with a peer. - #[error("memberlist: no response from node {0}")] - Lost(Node), - /// Returned when a message is received with an unexpected type. - #[error("memberlist: unexpected message: expected {expected}, got {got}")] - UnexpectedMessage { - /// The expected message type. - expected: &'static str, - /// The actual message type. - got: &'static str, - }, - /// Returned when the sequence number of [`Ack`](crate::types::Ack) is not - /// match the sequence number of [`Ping`](crate::types::Ping). - #[error("memberlist: sequence number mismatch: ping({ping}), ack({ack})")] - SequenceNumberMismatch { - /// The sequence number of [`Ping`](crate::types::Ping). - ping: u32, - /// The sequence number of [`Ack`](crate::types::Ack). - ack: u32, - }, - /// Returned when a remote error is received. - #[error("memberlist: remote error: {0}")] - Remote(SmolStr), - /// Returned when a custom error is created by users. - #[error("memberlist: {0}")] - Other(Cow<'static, str>), -} - /// Relay error from remote nodes. 
pub struct RelayError( #[allow(clippy::type_complexity)] diff --git a/serf-core/src/event.rs b/serf-core/src/event.rs index a0d6384..faeb747 100644 --- a/serf-core/src/event.rs +++ b/serf-core/src/event.rs @@ -1,7 +1,5 @@ use std::{pin::Pin, sync::Arc, task::Poll, time::Duration}; -use crate::delegate::TransformDelegate; - use self::error::Error; use super::{delegate::Delegate, types::Epoch, *}; @@ -18,7 +16,7 @@ use memberlist_core::{ CheapClone, bytes::{BufMut, Bytes, BytesMut}, transport::{AddressResolver, Transport}, - types::TinyVec, + proto::TinyVec, }; use serf_proto::{ LamportTime, Member, MessageType, Node, QueryFlag, QueryResponseMessage, UserEventMessage, @@ -101,11 +99,11 @@ where flags: QueryFlag::empty(), payload: msg, }; - let expected_encoded_len = ::message_encoded_len(&resp); + let expected_encoded_len = ::message_encoded_len(&resp); let mut buf = BytesMut::with_capacity(expected_encoded_len + 1); // +1 for the message type byte buf.put_u8(MessageType::QueryResponse as u8); buf.resize(expected_encoded_len + 1, 0); - let len = ::encode_message(&resp, &mut buf[1..]) + let len = ::encode_message(&resp, &mut buf[1..]) .map_err(Error::transform_delegate)?; debug_assert_eq!( len, expected_encoded_len, diff --git a/serf-core/src/event/crate_event.rs b/serf-core/src/event/crate_event.rs index 2b7ba33..df71d7b 100644 --- a/serf-core/src/event/crate_event.rs +++ b/serf-core/src/event/crate_event.rs @@ -3,13 +3,13 @@ use serf_proto::QueryMessage; use super::*; pub(crate) trait QueryMessageExt { - fn decode_internal_query( + fn decode_internal_query( &self, ) -> Option, T::Error>>; } impl QueryMessageExt for QueryMessage { - fn decode_internal_query( + fn decode_internal_query( &self, ) -> Option, T::Error>> { Some(Ok(match self.name().as_str() { diff --git a/serf-core/src/key_manager.rs b/serf-core/src/key_manager.rs index 3ee43c4..10ab920 100644 --- a/serf-core/src/key_manager.rs +++ b/serf-core/src/key_manager.rs @@ -19,7 +19,7 @@ use crate::event::{ use 
super::{ Serf, - delegate::{Delegate, TransformDelegate}, + delegate::{Delegate, }, error::Error, serf::{NodeResponse, QueryResponse}, types::{KeyRequestMessage, MessageType, SerfMessage}, @@ -186,12 +186,12 @@ where event: InternalQueryEvent, ) -> Result, Error> { let kr = KeyRequestMessage { key }; - let expected_encoded_len = ::message_encoded_len(&kr); + let expected_encoded_len = ::message_encoded_len(&kr); let mut buf = BytesMut::with_capacity(expected_encoded_len + 1); // +1 for the message type buf.put_u8(MessageType::KeyRequest as u8); buf.resize(expected_encoded_len + 1, 0); // Encode the query request - let len = ::encode_message(&kr, &mut buf[1..]) + let len = ::encode_message(&kr, &mut buf[1..]) .map_err(Error::transform_delegate)?; debug_assert_eq!( @@ -266,7 +266,7 @@ where } let node_response = - match ::decode_message(MessageType::KeyResponse, &r.payload[1..]) { + match ::decode_message(MessageType::KeyResponse, &r.payload[1..]) { Ok((_, nr)) => match nr { SerfMessage::KeyResponse(kr) => kr, msg => { diff --git a/serf-core/src/lib.rs b/serf-core/src/lib.rs index c38d3f1..cc9cba2 100644 --- a/serf-core/src/lib.rs +++ b/serf-core/src/lib.rs @@ -1,127 +1,127 @@ -// #![doc = include_str!("../../README.md")] -// #![doc(html_logo_url = "https://raw.githubusercontent.com/al8n/serf/main/art/logo_72x72.png")] -// #![forbid(unsafe_code)] -// #![deny(warnings, missing_docs)] -// #![allow(clippy::type_complexity)] -// #![cfg_attr(docsrs, feature(doc_cfg))] -// #![cfg_attr(docsrs, allow(unused_attributes))] - -// pub(crate) mod broadcast; - -// mod coalesce; - -// /// Coordinate. -// pub mod coordinate; - -// /// Events for [`Serf`] -// pub mod event; - -// /// Errors for `serf`. -// pub mod error; - -// /// Delegate traits and its implementations. -// pub mod delegate; - -// mod options; -// pub use options::*; - -// /// The types used in `serf`. -// pub mod types; - -// /// Secret key management. 
-// #[cfg(feature = "encryption")] -// #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] -// pub mod key_manager; - -// mod serf; -// pub use serf::*; - -// mod snapshot; -// pub use snapshot::*; - -// fn invalid_data_io_error(e: E) -> std::io::Error { -// std::io::Error::new(std::io::ErrorKind::InvalidData, e) -// } - -// /// All unit test fns are exported in the `tests` module. -// /// This module is used for users want to use other async runtime, -// /// and want to use the test if memberlist also works with their runtime. -// #[cfg(feature = "test")] -// #[cfg_attr(docsrs, doc(cfg(feature = "test")))] -// pub mod tests { -// pub use memberlist_core::tests::{AnyError, next_socket_addr_v4, next_socket_addr_v6}; -// pub use paste; - -// pub use super::serf::base::tests::{serf::*, *}; - -// /// Add `test` prefix to the predefined unit test fn with a given [`Runtime`](memberlist_core::agnostic_lite::RuntimeLite) -// #[cfg(any(feature = "test", test))] -// #[cfg_attr(docsrs, doc(cfg(any(feature = "test", test))))] -// #[macro_export] -// macro_rules! unit_tests { -// ($runtime:ty => $run:ident($($fn:ident), +$(,)?)) => { -// $( -// ::serf_core::tests::paste::paste! { -// #[test] -// fn [< test_ $fn >] () { -// $run($fn::<$runtime>()); -// } -// } -// )* -// }; -// } - -// /// Add `test` prefix to the predefined unit test fn with a given [`Runtime`](memberlist_core::agnostic_lite::RuntimeLite) -// #[cfg(any(feature = "test", test))] -// #[cfg_attr(docsrs, doc(cfg(any(feature = "test", test))))] -// #[macro_export] -// macro_rules! unit_tests_with_expr { -// ($run:ident($( -// $(#[$outer:meta])* -// $fn:ident( $expr:expr ) -// ), +$(,)?)) => { -// $( -// ::serf_core::tests::paste::paste! { -// #[test] -// $(#[$outer])* -// fn [< test_ $fn >] () { -// $run(async move { -// $expr -// }); -// } -// } -// )* -// }; -// } - -// /// Initialize the tracing for the unit tests. 
-// pub fn initialize_tests_tracing() { -// use std::sync::Once; -// static TRACE: Once = Once::new(); -// TRACE.call_once(|| { -// let filter = std::env::var("RUSERF_TESTING_LOG") -// .unwrap_or_else(|_| "serf_core=info,memberlist_core=debug".to_owned()); -// memberlist_core::tracing::subscriber::set_global_default( -// tracing_subscriber::fmt::fmt() -// .without_time() -// .with_line_number(true) -// .with_env_filter(filter) -// .with_file(false) -// .with_target(true) -// .with_ansi(true) -// .finish(), -// ) -// .unwrap(); -// }); -// } - -// /// Run the unit test with a given async runtime sequentially. -// pub fn run(block_on: B, fut: F) -// where -// B: FnOnce(F) -> F::Output, -// F: std::future::Future, -// { -// // initialize_tests_tracing(); -// block_on(fut); -// } -// } +#![doc = include_str!("../../README.md")] +#![doc(html_logo_url = "https://raw.githubusercontent.com/al8n/serf/main/art/logo_72x72.png")] +#![forbid(unsafe_code)] +#![deny(warnings, missing_docs)] +#![allow(clippy::type_complexity)] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr(docsrs, allow(unused_attributes))] + +pub(crate) mod broadcast; + +mod coalesce; + +/// Coordinate. +pub mod coordinate; + +/// Events for [`Serf`] +pub mod event; + +/// Errors for `serf`. +pub mod error; + +/// Delegate traits and its implementations. +pub mod delegate; + +mod options; +pub use options::*; + +/// The types used in `serf`. +pub mod types; + +/// Secret key management. +#[cfg(feature = "encryption")] +#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] +pub mod key_manager; + +mod serf; +pub use serf::*; + +mod snapshot; +pub use snapshot::*; + +fn invalid_data_io_error(e: E) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::InvalidData, e) +} + +/// All unit test fns are exported in the `tests` module. +/// This module is used for users want to use other async runtime, +/// and want to use the test if memberlist also works with their runtime. 
+#[cfg(feature = "test")] +#[cfg_attr(docsrs, doc(cfg(feature = "test")))] +pub mod tests { + pub use memberlist_core::tests::{AnyError, next_socket_addr_v4, next_socket_addr_v6}; + pub use paste; + + pub use super::serf::base::tests::{serf::*, *}; + + /// Add `test` prefix to the predefined unit test fn with a given [`Runtime`](memberlist_core::agnostic_lite::RuntimeLite) + #[cfg(any(feature = "test", test))] + #[cfg_attr(docsrs, doc(cfg(any(feature = "test", test))))] + #[macro_export] + macro_rules! unit_tests { + ($runtime:ty => $run:ident($($fn:ident), +$(,)?)) => { + $( + ::serf_core::tests::paste::paste! { + #[test] + fn [< test_ $fn >] () { + $run($fn::<$runtime>()); + } + } + )* + }; + } + + /// Add `test` prefix to the predefined unit test fn with a given [`Runtime`](memberlist_core::agnostic_lite::RuntimeLite) + #[cfg(any(feature = "test", test))] + #[cfg_attr(docsrs, doc(cfg(any(feature = "test", test))))] + #[macro_export] + macro_rules! unit_tests_with_expr { + ($run:ident($( + $(#[$outer:meta])* + $fn:ident( $expr:expr ) + ), +$(,)?)) => { + $( + ::serf_core::tests::paste::paste! { + #[test] + $(#[$outer])* + fn [< test_ $fn >] () { + $run(async move { + $expr + }); + } + } + )* + }; + } + + /// Initialize the tracing for the unit tests. + pub fn initialize_tests_tracing() { + use std::sync::Once; + static TRACE: Once = Once::new(); + TRACE.call_once(|| { + let filter = std::env::var("RUSERF_TESTING_LOG") + .unwrap_or_else(|_| "serf_core=info,memberlist_core=debug".to_owned()); + memberlist_core::tracing::subscriber::set_global_default( + tracing_subscriber::fmt::fmt() + .without_time() + .with_line_number(true) + .with_env_filter(filter) + .with_file(false) + .with_target(true) + .with_ansi(true) + .finish(), + ) + .unwrap(); + }); + } + + /// Run the unit test with a given async runtime sequentially. 
+ pub fn run(block_on: B, fut: F) + where + B: FnOnce(F) -> F::Output, + F: std::future::Future, + { + // initialize_tests_tracing(); + block_on(fut); + } +} diff --git a/serf-core/src/options.rs b/serf-core/src/options.rs index f85249d..b59c8bc 100644 --- a/serf-core/src/options.rs +++ b/serf-core/src/options.rs @@ -559,7 +559,7 @@ pub(crate) struct QueueOptions { pub(crate) check_interval: Duration, pub(crate) depth_warning: usize, #[cfg(feature = "metrics")] - pub(crate) metric_labels: Arc, + pub(crate) metric_labels: Arc, } #[cfg(feature = "serde")] diff --git a/serf-core/src/serf.rs b/serf-core/src/serf.rs index d392fce..7a6b53e 100644 --- a/serf-core/src/serf.rs +++ b/serf-core/src/serf.rs @@ -11,7 +11,7 @@ use memberlist_core::{ agnostic_lite::{AsyncSpawner, RuntimeLite}, queue::TransmitLimitedQueue, transport::{AddressResolver, Transport}, - types::MediumVec, + proto::MediumVec, }; use super::{ diff --git a/serf-core/src/serf/api.rs b/serf-core/src/serf/api.rs index 98f3ad0..d40dd94 100644 --- a/serf-core/src/serf/api.rs +++ b/serf-core/src/serf/api.rs @@ -2,16 +2,12 @@ use std::sync::atomic::Ordering; use futures::{FutureExt, StreamExt}; use memberlist_core::{ - CheapClone, - bytes::{BufMut, Bytes, BytesMut}, - tracing, - transport::{MaybeResolvedAddress, Node}, - types::{Meta, OneOrMore, SmallVec}, + bytes::{BufMut, Bytes, BytesMut}, proto::Data, tracing, transport::{MaybeResolvedAddress, Node}, types::{Meta, OneOrMore, SmallVec}, CheapClone }; use smol_str::SmolStr; use crate::{ - delegate::TransformDelegate, + delegate::, error::{Error, JoinError}, event::EventProducer, types::{LeaveMessage, Member, MessageType, SerfMessage, Tags, UserEventMessage}, @@ -224,7 +220,7 @@ where #[inline] pub async fn set_tags(&self, tags: Tags) -> Result<(), Error> { // Check that the meta data length is okay - let tags_encoded_len = ::tags_encoded_len(&tags); + let tags_encoded_len = tags.encoded_len_with_length_delimited(); if tags_encoded_len > Meta::MAX_SIZE { return 
Err(Error::tags_too_large(tags_encoded_len)); } @@ -274,7 +270,7 @@ where }; // Start broadcasting the event - let len = ::message_encoded_len(&msg); + let len = ::message_encoded_len(&msg); // Check the size after encoding to be sure again that // we're not attempting to send over the specified size limit. @@ -290,7 +286,7 @@ where raw.put_u8(MessageType::UserEvent as u8); raw.resize(len + 1, 0); - let actual_encoded_len = ::encode_message(&msg, &mut raw[1..]) + let actual_encoded_len = ::encode_message(&msg, &mut raw[1..]) .map_err(Error::transform_delegate)?; debug_assert_eq!( actual_encoded_len, len, diff --git a/serf-core/src/serf/base.rs b/serf-core/src/serf/base.rs index 56b90e6..248d57c 100644 --- a/serf-core/src/serf/base.rs +++ b/serf-core/src/serf/base.rs @@ -17,7 +17,7 @@ use crate::{ QueueOptions, coalesce::{MemberEventCoalescer, UserEventCoalescer, coalesced_event}, coordinate::CoordinateOptions, - delegate::TransformDelegate, + delegate::, error::Error, event::{InternalQueryEvent, MemberEvent, MemberEventType, QueryContext, QueryEvent}, snapshot::{Snapshot, open_and_replay_snapshot}, @@ -74,7 +74,7 @@ where { let tags = opts.tags.load(); if !tags.as_ref().is_empty() { - let len = ::tags_encoded_len(&tags); + let len = ::tags_encoded_len(&tags); if len > Meta::MAX_SIZE { return Err(Error::tags_too_large(len)); } @@ -366,11 +366,11 @@ where notify_tx: Option>, ) -> Result<(), Error> { let ty = MessageType::from(&msg); - let expected_encoded_len = ::message_encoded_len(&msg); + let expected_encoded_len = ::message_encoded_len(&msg); let mut raw = BytesMut::with_capacity(expected_encoded_len + 1); // + 1 for message type byte raw.put_u8(ty as u8); raw.resize(expected_encoded_len + 1, 0); - let len = ::encode_message(&msg, &mut raw[1..]) + let len = ::encode_message(&msg, &mut raw[1..]) .map_err(Error::transform_delegate)?; debug_assert_eq!( len, expected_encoded_len, @@ -913,7 +913,7 @@ where }; // Encode the query - let len = ::message_encoded_len(&q); 
+ let len = ::message_encoded_len(&q); // Check the size if len > self.inner.opts.query_size_limit { @@ -923,7 +923,7 @@ where let mut raw = BytesMut::with_capacity(len + 1); // + 1 for message type byte raw.put_u8(MessageType::Query as u8); raw.resize(len + 1, 0); - let actual_encoded_len = ::encode_message(&q, &mut raw[1..]) + let actual_encoded_len = ::encode_message(&q, &mut raw[1..]) .map_err(Error::transform_delegate)?; debug_assert_eq!( actual_encoded_len, len, @@ -1067,12 +1067,12 @@ where payload: Bytes::new(), }; - let expected_encoded_len = ::message_encoded_len(&ack); + let expected_encoded_len = ::message_encoded_len(&ack); let mut raw = BytesMut::with_capacity(expected_encoded_len + 1); // + 1 for message type byte raw.put_u8(MessageType::QueryResponse as u8); raw.resize(expected_encoded_len + 1, 0); - match ::encode_message(&ack, &mut raw[1..]) { + match ::encode_message(&ack, &mut raw[1..]) { Ok(len) => { debug_assert_eq!( len, expected_encoded_len, @@ -1181,7 +1181,7 @@ where let node = n.node(); let tags = if !n.meta().is_empty() { - match ::decode_tags(n.meta()) { + match ::decode_tags(n.meta()) { Ok((readed, tags)) => { tracing::trace!(read = %readed, tags=?tags, "serf: decode tags successfully"); tags @@ -1544,7 +1544,7 @@ where &self, n: Arc::ResolvedAddress>>, ) { - let tags = match ::decode_tags(n.meta()) { + let tags = match ::decode_tags(n.meta()) { Ok((readed, tags)) => { tracing::trace!(read = %readed, tags=?tags, "serf: decode tags successfully"); tags @@ -1662,10 +1662,10 @@ where // Get the local node let local_id = self.inner.memberlist.local_id(); let local_advertise_addr = self.inner.memberlist.advertise_address(); - let encoded_id_len = ::id_encoded_len(local_id); + let encoded_id_len = ::id_encoded_len(local_id); let mut payload = vec![0u8; encoded_id_len]; - if let Err(e) = ::encode_id(local_id, &mut payload) { + if let Err(e) = ::encode_id(local_id, &mut payload) { tracing::error!(err=%e, "serf: failed to encode local id"); 
return; } @@ -1699,7 +1699,7 @@ where continue; } - match ::decode_message(MessageType::ConflictResponse, &r.payload[1..]) + match ::decode_message(MessageType::ConflictResponse, &r.payload[1..]) { Ok((_, decoded)) => { match decoded { diff --git a/serf-core/src/serf/base/tests.rs b/serf-core/src/serf/base/tests.rs index 4c0459d..f7c4564 100644 --- a/serf-core/src/serf/base/tests.rs +++ b/serf-core/src/serf/base/tests.rs @@ -15,7 +15,7 @@ use serf_proto::{ use smol_str::SmolStr; use crate::{ - delegate::TransformDelegate, + delegate::, event::{CrateEvent, CrateEventType, MemberEvent, MemberEventType}, types::Epoch, }; @@ -339,7 +339,7 @@ pub async fn estimate_max_keys_in_list_key_response_factor( ) where T: Transport, { - use memberlist_core::types::SecretKey; + use memberlist_core::proto::SecretKey; use serf_proto::KeyResponseMessage; let size_limit = opts.query_response_size_limit() * 10; @@ -364,14 +364,14 @@ pub async fn estimate_max_keys_in_list_key_response_factor( let mut found = 0; for i in (0..=resp.keys.len()).rev() { - let encoded_len = as TransformDelegate>::message_encoded_len(&resp); + let encoded_len = as >::message_encoded_len(&resp); let mut dst = vec![0; encoded_len]; - as TransformDelegate>::encode_message(&resp, &mut dst).unwrap(); + as >::encode_message(&resp, &mut dst).unwrap(); let qresp = query.create_response(dst.into()); - let encoded_len = as TransformDelegate>::message_encoded_len(&qresp); + let encoded_len = as >::message_encoded_len(&qresp); let mut dst = vec![0; encoded_len]; - as TransformDelegate>::encode_message(&qresp, &mut dst).unwrap(); + as >::encode_message(&qresp, &mut dst).unwrap(); if query.check_response_size(&dst).is_err() { resp.keys.truncate(i); @@ -399,7 +399,7 @@ pub async fn key_list_key_response_with_correct_size(transport_opts: T::Optio where T: Transport, { - use memberlist_core::types::SecretKey; + use memberlist_core::proto::SecretKey; use serf_proto::{Encodable, KeyResponseMessage}; let opts = 
opts.with_query_response_size_limit(1024); diff --git a/serf-core/src/serf/base/tests/serf.rs b/serf-core/src/serf/base/tests/serf.rs index c68ef27..08209e6 100644 --- a/serf-core/src/serf/base/tests/serf.rs +++ b/serf-core/src/serf/base/tests/serf.rs @@ -782,7 +782,7 @@ where /// Unit test for serf write keying file #[cfg(feature = "encryption")] pub async fn serf_write_keyring_file( - get_transport_opts: impl FnOnce(memberlist_core::types::SecretKey) -> T::Options, + get_transport_opts: impl FnOnce(memberlist_core::proto::SecretKey) -> T::Options, ) where T: Transport, { @@ -798,7 +798,7 @@ pub async fn serf_write_keyring_file( p.set_extension("json"); let existing_bytes = general_purpose::STANDARD.decode(EXISTING).unwrap(); - let sk = memberlist_core::types::SecretKey::try_from(existing_bytes.as_slice()).unwrap(); + let sk = memberlist_core::proto::SecretKey::try_from(existing_bytes.as_slice()).unwrap(); let serf = Serf::::new( get_transport_opts(sk), @@ -813,7 +813,7 @@ pub async fn serf_write_keyring_file( let manager = serf.key_manager(); let new_key = general_purpose::STANDARD.decode(NEW_KEY).unwrap(); - let new_sk = memberlist_core::types::SecretKey::try_from(new_key.as_slice()).unwrap(); + let new_sk = memberlist_core::proto::SecretKey::try_from(new_key.as_slice()).unwrap(); manager.install_key(new_sk, None).await.unwrap(); let mut keyring_file = std::fs::File::open(&p).unwrap(); diff --git a/serf-core/src/serf/base/tests/serf/delegate.rs b/serf-core/src/serf/base/tests/serf/delegate.rs index 8138c1d..c81f10a 100644 --- a/serf-core/src/serf/base/tests/serf/delegate.rs +++ b/serf-core/src/serf/base/tests/serf/delegate.rs @@ -14,7 +14,7 @@ where .unwrap(); let meta = s.inner.memberlist.delegate().unwrap().node_meta(32).await; - let (_, tags) = as TransformDelegate>::decode_tags(&meta).unwrap(); + let (_, tags) = as >::decode_tags(&meta).unwrap(); assert_eq!(tags.get("role"), Some(&SmolStr::new("test"))); s.shutdown().await.unwrap(); @@ -82,7 +82,7 @@ where 
// Attempt a decode let (_, pp) = - as TransformDelegate>::decode_message(MessageType::PushPull, &buf[1..]) + as >::decode_message(MessageType::PushPull, &buf[1..]) .unwrap(); let SerfMessage::PushPull(pp) = pp else { @@ -147,9 +147,9 @@ where query_ltime: 100.into(), }; - let mut buf = vec![0; as TransformDelegate>::message_encoded_len(&pp) + 1]; + let mut buf = vec![0; as >::message_encoded_len(&pp) + 1]; buf[0] = MessageType::PushPull as u8; - as TransformDelegate>::encode_message(&pp, &mut buf[1..]).unwrap(); + as >::encode_message(&pp, &mut buf[1..]).unwrap(); // Merge in fake state d.merge_remote_state(buf.into(), false).await; diff --git a/serf-core/src/serf/base/tests/serf/event.rs b/serf-core/src/serf/base/tests/serf/event.rs index 7f4ddbe..0a36144 100644 --- a/serf-core/src/serf/base/tests/serf/event.rs +++ b/serf-core/src/serf/base/tests/serf/event.rs @@ -571,15 +571,15 @@ where assert_eq!(filters.len(), 3); let (_, node_filt) = - as TransformDelegate>::decode_filter(&filters[0]).unwrap(); + as >::decode_filter(&filters[0]).unwrap(); assert_eq!(node_filt.ty(), FilterType::Id); let (_, tag_filt) = - as TransformDelegate>::decode_filter(&filters[1]).unwrap(); + as >::decode_filter(&filters[1]).unwrap(); assert_eq!(tag_filt.ty(), FilterType::Tag); let (_, tag_filt) = - as TransformDelegate>::decode_filter(&filters[2]).unwrap(); + as >::decode_filter(&filters[2]).unwrap(); assert_eq!(tag_filt.ty(), FilterType::Tag); } diff --git a/serf-core/src/serf/base/tests/serf/join.rs b/serf-core/src/serf/base/tests/serf/join.rs index 4ee840b..08916ce 100644 --- a/serf-core/src/serf/base/tests/serf/join.rs +++ b/serf-core/src/serf/base/tests/serf/join.rs @@ -293,7 +293,7 @@ pub async fn join_pending_intent( id: "test".into(), addr, meta: Meta::empty(), - state: memberlist_core::types::State::Alive, + state: memberlist_core::proto::State::Alive, protocol_version: serf_proto::MemberlistProtocolVersion::V1, delegate_version: serf_proto::MemberlistDelegateVersion::V1, })) 
@@ -340,7 +340,7 @@ pub async fn join_pending_intents( id: "test".into(), addr, meta: Meta::empty(), - state: memberlist_core::types::State::Alive, + state: memberlist_core::proto::State::Alive, protocol_version: serf_proto::MemberlistProtocolVersion::V1, delegate_version: serf_proto::MemberlistDelegateVersion::V1, })) diff --git a/serf-core/src/serf/delegate.rs b/serf-core/src/serf/delegate.rs index 9db1949..0ca379f 100644 --- a/serf-core/src/serf/delegate.rs +++ b/serf-core/src/serf/delegate.rs @@ -1,7 +1,7 @@ use crate::{ Serf, broadcast::SerfBroadcast, - delegate::{Delegate, TransformDelegate}, + delegate::{Delegate, }, error::{SerfDelegateError, SerfError}, event::QueryMessageExt, types::{ @@ -117,7 +117,7 @@ where let tags = self.tags.load(); match tags.is_empty() { false => { - let encoded_len = ::tags_encoded_len(&tags); + let encoded_len = ::tags_encoded_len(&tags); let limit = limit.min(Meta::MAX_SIZE); if encoded_len > limit { panic!( @@ -127,7 +127,7 @@ where } let mut role_bytes = vec![0; encoded_len]; - match ::encode_tags(&tags, &mut role_bytes) { + match ::encode_tags(&tags, &mut role_bytes) { Ok(len) => { debug_assert_eq!( len, encoded_len, @@ -190,7 +190,7 @@ where } match ty { - MessageType::Leave => match ::decode_message(ty, &msg[1..]) { + MessageType::Leave => match ::decode_message(ty, &msg[1..]) { Ok((_, l)) => { if let SerfMessage::Leave(l) = &l { tracing::debug!("serf: leave message: {}", l.id()); @@ -203,7 +203,7 @@ where tracing::warn!(err=%e, "serf: failed to decode message"); } }, - MessageType::Join => match ::decode_message(ty, &msg[1..]) { + MessageType::Join => match ::decode_message(ty, &msg[1..]) { Ok((_, j)) => { if let SerfMessage::Join(j) = &j { tracing::debug!("serf: join message: {}", j.id()); @@ -216,7 +216,7 @@ where tracing::warn!(err=%e, "serf: failed to decode message"); } }, - MessageType::UserEvent => match ::decode_message(ty, &msg[1..]) { + MessageType::UserEvent => match ::decode_message(ty, &msg[1..]) { Ok((_, 
ue)) => { if let SerfMessage::UserEvent(ue) = ue { tracing::debug!("serf: user event message: {}", ue.name); @@ -230,7 +230,7 @@ where tracing::warn!(err=%e, "serf: failed to decode message"); } }, - MessageType::Query => match ::decode_message(ty, &msg[1..]) { + MessageType::Query => match ::decode_message(ty, &msg[1..]) { Ok((_, q)) => { if let SerfMessage::Query(q) = q { tracing::debug!("serf: query message: {}", q.name); @@ -256,7 +256,7 @@ where } }, MessageType::QueryResponse => { - match ::decode_message(ty, &msg[1..]) { + match ::decode_message(ty, &msg[1..]) { Ok((_, qr)) => { if let SerfMessage::QueryResponse(qr) = qr { tracing::debug!("serf: query response message: {}", qr.from); @@ -270,7 +270,7 @@ where } } } - MessageType::Relay => match ::decode_node(&msg[1..]) { + MessageType::Relay => match ::decode_node(&msg[1..]) { Ok((consumed, n)) => { tracing::debug!("serf: relay message",); tracing::debug!("serf: relaying response to node: {}", n); @@ -399,11 +399,11 @@ where }; drop(members); - let expected_encoded_len = ::message_encoded_len(pp); + let expected_encoded_len = ::message_encoded_len(pp); let mut buf = BytesMut::with_capacity(expected_encoded_len + 1); // +1 for the message type byte buf.put_u8(MessageType::PushPull as u8); buf.resize(expected_encoded_len + 1, 0); - match ::encode_message(pp, &mut buf[1..]) { + match ::encode_message(pp, &mut buf[1..]) { Ok(encoded_len) => { debug_assert_eq!( expected_encoded_len, encoded_len, @@ -449,7 +449,7 @@ where match ty { MessageType::PushPull => { - match ::decode_message(ty, &buf[1..]) { + match ::decode_message(ty, &buf[1..]) { Err(e) => { tracing::error!(err=%e, "serf: failed to decode remote state"); } @@ -677,9 +677,9 @@ where coord.portion.resize(len * 2, 0.0); // The rest of the message is the serialized coordinate. 
- let len = ::coordinate_encoded_len(&coord); + let len = ::coordinate_encoded_len(&coord); buf.resize(len + 1, 0); - if let Err(e) = ::encode_coordinate(&coord, &mut buf[1..]) { + if let Err(e) = ::encode_coordinate(&coord, &mut buf[1..]) { panic!("failed to encode coordinate: {}", e); } return buf.freeze(); @@ -687,12 +687,12 @@ where if let Some(c) = self.this().inner.coord_core.as_ref() { let coord = c.client.get_coordinate(); - let encoded_len = ::coordinate_encoded_len(&coord) + 1; + let encoded_len = ::coordinate_encoded_len(&coord) + 1; let mut buf = BytesMut::with_capacity(encoded_len); buf.put_u8(PING_VERSION); buf.resize(encoded_len, 0); - if let Err(e) = ::encode_coordinate(&coord, &mut buf[1..]) { + if let Err(e) = ::encode_coordinate(&coord, &mut buf[1..]) { tracing::error!(err=%e, "serf: failed to encode coordinate"); } buf.into() @@ -721,7 +721,7 @@ where } // Process the remainder of the message as a coordinate. - let coord = match ::decode_coordinate(&payload[1..]) { + let coord = match ::decode_coordinate(&payload[1..]) { Ok((readed, c)) => { tracing::trace!(read=%readed, coordinate=?c, "serf: decode coordinate successfully"); c @@ -815,7 +815,7 @@ where Ok(Member { node: node.node(), tags: if !node.meta().is_empty() { - ::decode_tags(node.meta()) + ::decode_tags(node.meta()) .map(|(read, tags)| { tracing::trace!(read=%read, tags=?tags, "serf: decode tags successfully"); Arc::new(tags) diff --git a/serf-core/src/serf/internal_query.rs b/serf-core/src/serf/internal_query.rs index 5503ce6..5f07c65 100644 --- a/serf-core/src/serf/internal_query.rs +++ b/serf-core/src/serf/internal_query.rs @@ -8,7 +8,7 @@ use memberlist_core::{ }; use crate::{ - delegate::{Delegate, TransformDelegate}, + delegate::{Delegate, }, event::{CrateEvent, InternalQueryEvent, QueryEvent}, types::MessageType, }; @@ -152,11 +152,11 @@ where match out { Some(state) => { let member = state.member(); - let expected_encoded_len = ::message_encoded_len(member); + let 
expected_encoded_len = ::message_encoded_len(member); let mut raw = BytesMut::with_capacity(expected_encoded_len + 1); // +1 for the message type raw.put_u8(MessageType::ConflictResponse as u8); raw.resize(expected_encoded_len + 1, 0); - match ::encode_message(member, &mut raw[1..]) { + match ::encode_message(member, &mut raw[1..]) { Ok(len) => { debug_assert_eq!( len, expected_encoded_len, @@ -193,7 +193,7 @@ where let q = ev.as_ref(); let mut response = KeyResponseMessage::default(); let req = - match ::decode_message(MessageType::KeyRequest, &q.payload[1..]) { + match ::decode_message(MessageType::KeyRequest, &q.payload[1..]) { Ok((_, msg)) => match msg { SerfMessage::KeyRequest(req) => req, msg => { @@ -257,7 +257,7 @@ where let mut response = KeyResponseMessage::default(); let req = - match ::decode_message(MessageType::KeyRequest, &q.payload[1..]) { + match ::decode_message(MessageType::KeyRequest, &q.payload[1..]) { Ok((_, msg)) => match msg { SerfMessage::KeyRequest(req) => req, msg => { @@ -327,7 +327,7 @@ where let mut response = KeyResponseMessage::default(); let req = - match ::decode_message(MessageType::KeyRequest, &q.payload[1..]) { + match ::decode_message(MessageType::KeyRequest, &q.payload[1..]) { Ok((_, msg)) => match msg { SerfMessage::KeyRequest(req) => req, msg => { @@ -450,12 +450,12 @@ where (q.ctx.this.inner.opts.query_response_size_limit / MIN_ENCODED_KEY_LENGTH).min(actual); for i in (0..=max_list_keys).rev() { - let expected_k_encoded_len = ::message_encoded_len(&*resp); + let expected_k_encoded_len = ::message_encoded_len(&*resp); let mut raw = BytesMut::with_capacity(expected_k_encoded_len + 1); // +1 for the message type raw.put_u8(MessageType::KeyResponse as u8); raw.resize(expected_k_encoded_len + 1, 0); - let len = ::encode_message(&*resp, &mut raw[1..]) + let len = ::encode_message(&*resp, &mut raw[1..]) .map_err(Error::transform_delegate)?; debug_assert_eq!( @@ -469,12 +469,12 @@ where let qresp = q.create_response(kraw.clone()); 
// encode response - let expected_encoded_len = ::message_encoded_len(&qresp); + let expected_encoded_len = ::message_encoded_len(&qresp); let mut raw = BytesMut::with_capacity(expected_encoded_len + 1); // +1 for the message type raw.put_u8(MessageType::QueryResponse as u8); raw.resize(expected_encoded_len + 1, 0); - let len = ::encode_message(&qresp, &mut raw[1..]) + let len = ::encode_message(&qresp, &mut raw[1..]) .map_err(Error::transform_delegate)?; debug_assert_eq!( @@ -520,11 +520,11 @@ where } } _ => { - let expected_encoded_len = ::message_encoded_len(&*resp); + let expected_encoded_len = ::message_encoded_len(&*resp); let mut raw = BytesMut::with_capacity(expected_encoded_len + 1); // +1 for the message type raw.put_u8(MessageType::KeyResponse as u8); raw.resize(expected_encoded_len + 1, 0); - match ::encode_message(&*resp, &mut raw[1..]) { + match ::encode_message(&*resp, &mut raw[1..]) { Ok(len) => { debug_assert_eq!( len, expected_encoded_len, diff --git a/serf-core/src/serf/query.rs b/serf-core/src/serf/query.rs index e7b218e..a8e59ae 100644 --- a/serf-core/src/serf/query.rs +++ b/serf-core/src/serf/query.rs @@ -16,7 +16,7 @@ use memberlist_core::{ }; use crate::{ - delegate::{Delegate, TransformDelegate}, + delegate::{Delegate, }, error::Error, types::{ Filter, LamportTime, Member, MemberStatus, MessageType, QueryMessage, QueryResponseMessage, @@ -92,23 +92,6 @@ pub struct QueryParam { timeout: Duration, } -impl QueryParam -where - I: Id, -{ - /// Used to convert the filters into the wire format - pub(crate) fn encode_filters>( - &self, - ) -> Result, W::Error> { - let mut filters = TinyVec::with_capacity(self.filters.len()); - for filter in self.filters.iter() { - filters.push(W::encode_filter(filter)?); - } - - Ok(filters) - } -} - struct QueryResponseChannel { /// Used to send the name of a node for which we've received an ack ack_ch: Option<(Sender>, Receiver>)>, @@ -258,7 +241,7 @@ impl QueryResponse { &self, resp: QueryResponseMessage, _local: 
&T::Id, - #[cfg(feature = "metrics")] metrics_labels: &memberlist_core::types::MetricLabels, + #[cfg(feature = "metrics")] metrics_labels: &memberlist_core::proto::MetricLabels, ) where I: Eq + std::hash::Hash + CheapClone + core::fmt::Debug, A: Eq + std::hash::Hash + CheapClone + core::fmt::Debug, @@ -461,7 +444,7 @@ where } // Decode the filter - let filter = match ::decode_filter(filter) { + let filter = match ::decode_filter(filter) { Ok((read, filter)) => { tracing::trace!(read=%read, filter=?filter, "serf: decoded filter successully"); filter @@ -549,9 +532,9 @@ where // Prep the relay message, which is a wrapped version of the original. // let relay_msg = SerfRelayMessage::new(node, SerfMessage::QueryResponse(resp)); let expected_encoded_len = 1 - + ::node_encoded_len(&node) + + ::node_encoded_len(&node) + 1 - + ::message_encoded_len(&resp); // +1 for relay message type byte, +1 for the message type + + ::message_encoded_len(&resp); // +1 for relay message type byte, +1 for the message type if expected_encoded_len > self.inner.opts.query_response_size_limit { return Err(Error::relayed_response_too_large( self.inner.opts.query_response_size_limit, @@ -562,11 +545,11 @@ where raw.put_u8(MessageType::Relay as u8); raw.resize(expected_encoded_len + 1 + 1, 0); let mut encoded = 1; - encoded += ::encode_node(&node, &mut raw[encoded..]) + encoded += ::encode_node(&node, &mut raw[encoded..]) .map_err(Error::transform_delegate)?; raw[encoded] = MessageType::QueryResponse as u8; encoded += 1; - encoded += ::encode_message(&resp, &mut raw[encoded..]) + encoded += ::encode_message(&resp, &mut raw[encoded..]) .map_err(Error::transform_delegate)?; debug_assert_eq!( diff --git a/serf-core/src/snapshot.rs b/serf-core/src/snapshot.rs index fb11c76..bdb9e04 100644 --- a/serf-core/src/snapshot.rs +++ b/serf-core/src/snapshot.rs @@ -20,13 +20,13 @@ use memberlist_core::{ bytes::{BufMut, BytesMut}, tracing, transport::{AddressResolver, Id, MaybeResolvedAddress, Node, Transport}, 
- types::TinyVec, + proto::TinyVec, }; use rand::seq::SliceRandom; use serf_proto::UserEventMessage; use crate::{ - delegate::{Delegate, TransformDelegate}, + delegate::Delegate, event::{CrateEvent, MemberEvent, MemberEventType}, invalid_data_io_error, types::{Epoch, LamportClock, LamportTime}, @@ -200,7 +200,7 @@ where const LEAVE: u8 = 6; const COMMENT: u8 = 7; - fn encode, W: Write>( + fn encode( &self, w: &mut W, ) -> std::io::Result { @@ -231,7 +231,6 @@ pub(crate) struct ReplayResult { pub(crate) fn open_and_replay_snapshot< I: Id, A: CheapClone + core::hash::Hash + Eq + Send + Sync + 'static, - T: TransformDelegate, P: AsRef, >( p: &P, @@ -397,7 +396,7 @@ where wait_tx: Sender<()>, last_attempted_compaction: Epoch, #[cfg(feature = "metrics")] - metric_labels: std::sync::Arc, + metric_labels: std::sync::Arc, } // flushEvent is used to handle writing out an event @@ -446,7 +445,7 @@ where clock: LamportClock, out_tx: Sender>, shutdown_rx: Receiver<()>, - #[cfg(feature = "metrics")] metric_labels: std::sync::Arc, + #[cfg(feature = "metrics")] metric_labels: std::sync::Arc, ) -> Result< ( Sender>, @@ -502,7 +501,7 @@ where Node::new(id, MaybeResolvedAddress::resolved(addr)) }) .collect::>(); - alive_nodes.shuffle(&mut rand::thread_rng()); + alive_nodes.shuffle(&mut rand::rng()); // Start handling new commands let handle = ::spawn(Self::tee_stream( diff --git a/serf-core/src/types/member.rs b/serf-core/src/types/member.rs index 9f4a5f6..43b7b3a 100644 --- a/serf-core/src/types/member.rs +++ b/serf-core/src/types/member.rs @@ -1,9 +1,9 @@ -use memberlist_core::types::OneOrMore; +use memberlist_core::proto::OneOrMore; use serf_proto::Member; use std::collections::HashMap; -use super::{Epoch, LamportTime, MessageType}; +use super::{Epoch, LamportTime}; /// Used to track members that are no longer active due to /// leaving, failing, partitioning, etc. 
It tracks the member along with From 756a53496c5e5108169867ef3e44227432185547 Mon Sep 17 00:00:00 2001 From: al8n Date: Thu, 27 Feb 2025 16:00:08 +0800 Subject: [PATCH 06/39] WIP --- serf-core/src/coalesce/member.rs | 568 +++++++-------- serf-core/src/coalesce/user.rs | 276 +++---- serf-core/src/delegate/composite.rs | 28 +- serf-core/src/delegate/merge.rs | 4 +- serf-core/src/delegate/reconnect.rs | 2 +- serf-core/src/error.rs | 139 +--- serf-core/src/event.rs | 82 +-- serf-core/src/event/crate_event.rs | 20 +- serf-core/src/key_manager.rs | 69 +- serf-core/src/lib.rs | 4 +- serf-core/src/serf.rs | 2 +- serf-core/src/serf/api.rs | 81 +-- serf-core/src/serf/base.rs | 92 +-- serf-core/src/serf/base/tests.rs | 13 +- .../src/serf/base/tests/serf/delegate.rs | 16 +- serf-core/src/serf/base/tests/serf/event.rs | 6 +- serf-core/src/serf/delegate.rs | 77 +- serf-core/src/serf/internal_query.rs | 202 ++---- serf-core/src/serf/query.rs | 42 +- serf-core/src/snapshot.rs | 45 +- serf-core/src/types/member.rs | 2 +- serf-proto/src/conflict.rs | 174 +++++ serf-proto/src/lib.rs | 7 +- serf-proto/src/message.rs | 679 ++++++++---------- serf-proto/src/push_pull.rs | 147 +++- serf/test/main.rs | 48 +- 26 files changed, 1400 insertions(+), 1425 deletions(-) create mode 100644 serf-proto/src/conflict.rs diff --git a/serf-core/src/coalesce/member.rs b/serf-core/src/coalesce/member.rs index c439c8a..4753326 100644 --- a/serf-core/src/coalesce/member.rs +++ b/serf-core/src/coalesce/member.rs @@ -3,8 +3,8 @@ use std::{collections::HashMap, marker::PhantomData}; use async_channel::Sender; use memberlist_core::{ CheapClone, - transport::{AddressResolver, Node, Transport}, proto::TinyVec, + transport::{AddressResolver, Node, Transport}, }; use crate::{ @@ -118,286 +118,286 @@ where } } -#[cfg(all(test, feature = "test"))] -#[allow(clippy::collapsible_match)] -mod tests { - use std::{net::SocketAddr, time::Duration}; - - use futures::FutureExt; - use memberlist_core::{ - 
agnostic_lite::{RuntimeLite, tokio::TokioRuntime}, - transport::resolver::socket_addr::SocketAddrResolver, - }; - use serf_proto::{MemberStatus, UserEventMessage}; - use smol_str::SmolStr; - - use crate::{ - DefaultDelegate, - coalesce::coalesced_event, - event::{CrateEventType, MemberEvent}, - }; - - use super::*; - - type Transport = UnimplementedTransport< - SmolStr, - SocketAddrResolver, - Lpe, - TokioRuntime, - >; - - type Delegate = DefaultDelegate; - - #[tokio::test] - async fn test_member_event_coealesce_basic() { - let (tx, rx) = async_channel::unbounded(); - let (_shutdown_tx, shutdown_rx) = async_channel::bounded(1); - let coalescer = MemberEventCoalescer::::new(); - - let in_ = coalesced_event( - tx, - shutdown_rx, - Duration::from_millis(20), - Duration::from_millis(20), - coalescer, - ); - - let send = vec![ - MemberEvent { - ty: MemberEventType::Join, - members: TinyVec::from(Member::new( - Node::new("foo".into(), "127.0.0.1:8080".parse().unwrap()), - Default::default(), - MemberStatus::None, - )) - .into(), - }, - MemberEvent { - ty: MemberEventType::Leave, - members: TinyVec::from(Member::new( - Node::new("foo".into(), "127.0.0.1:8080".parse().unwrap()), - Default::default(), - MemberStatus::None, - )) - .into(), - }, - MemberEvent { - ty: MemberEventType::Leave, - members: TinyVec::from(Member::new( - Node::new("bar".into(), "127.0.0.1:8080".parse().unwrap()), - Default::default(), - MemberStatus::None, - )) - .into(), - }, - MemberEvent { - ty: MemberEventType::Update, - members: TinyVec::from(Member::new( - Node::new("zip".into(), "127.0.0.1:8080".parse().unwrap()), - [("role", "foo")].into_iter().collect(), - MemberStatus::None, - )) - .into(), - }, - MemberEvent { - ty: MemberEventType::Update, - members: TinyVec::from(Member::new( - Node::new("zip".into(), "127.0.0.1:8080".parse().unwrap()), - [("role", "bar")].into_iter().collect(), - MemberStatus::None, - )) - .into(), - }, - MemberEvent { - ty: MemberEventType::Reap, - members: 
TinyVec::from(Member::new( - Node::new("dead".into(), "127.0.0.1:8080".parse().unwrap()), - Default::default(), - MemberStatus::None, - )) - .into(), - }, - ]; - - for event in send { - in_.send(CrateEvent::from(event)).await.unwrap(); - } - - let mut events = HashMap::new(); - let timeout = TokioRuntime::sleep(Duration::from_millis(40)); - futures::pin_mut!(timeout); - loop { - futures::select! { - e = rx.recv().fuse() => { - let e = e.unwrap(); - events.insert(e.ty(), e.clone()); - } - _ = (&mut timeout).fuse() => { - break; - }, - } - } - - assert_eq!(events.len(), 3); - - match events.get(&CrateEventType::Member(MemberEventType::Leave)) { - None => panic!(""), - Some(e) => match e { - CrateEvent::Member(MemberEvent { members, .. }) => { - assert_eq!(members.len(), 2); - - let expected = ["bar", "foo"]; - let mut names = [members[0].node.id().clone(), members[1].node.id().clone()]; - names.sort(); - - assert_eq!(names, expected); - } - _ => panic!(""), - }, - } - - match events.get(&CrateEventType::Member(MemberEventType::Update)) { - None => panic!(""), - Some(e) => match e { - CrateEvent::Member(MemberEvent { members, .. }) => { - assert_eq!(members.len(), 1); - assert_eq!(members[0].node.id(), "zip"); - assert_eq!(members[0].tags().get("role").unwrap(), "bar"); - } - _ => panic!(""), - }, - } - - match events.get(&CrateEventType::Member(MemberEventType::Reap)) { - None => panic!(""), - Some(e) => match e { - CrateEvent::Member(MemberEvent { members, .. 
}) => { - assert_eq!(members.len(), 1); - assert_eq!(members[0].node.id(), "dead"); - } - _ => panic!(""), - }, - } - } - - #[tokio::test] - async fn test_member_event_coalesce_tag_update() { - let (tx, rx) = async_channel::unbounded(); - let (_shutdown_tx, shutdown_rx) = async_channel::bounded(1); - let coalescer = MemberEventCoalescer::::new(); - - let in_ = coalesced_event( - tx, - shutdown_rx, - Duration::from_millis(5), - Duration::from_millis(5), - coalescer, - ); - - in_ - .send(CrateEvent::from(MemberEvent { - ty: MemberEventType::Update, - members: TinyVec::from(Member::new( - Node::new("foo".into(), "127.0.0.1:8080".parse().unwrap()), - [("role", "foo")].into_iter().collect(), - MemberStatus::None, - )) - .into(), - })) - .await - .unwrap(); - - TokioRuntime::sleep(Duration::from_millis(30)).await; - - futures::select! { - e = rx.recv().fuse() => { - let e = e.unwrap(); - - match e { - CrateEvent::Member(MemberEvent { ty, .. }) => { - assert!(matches!(ty, MemberEventType::Update)); - } - _ => panic!("expected update"), - } - } - default => panic!("expected update"), - } - - // Second update should not be suppressed even though - // last event was an update - in_ - .send(CrateEvent::from(MemberEvent { - ty: MemberEventType::Update, - members: TinyVec::from(Member::new( - Node::new("foo".into(), "127.0.0.1:8080".parse().unwrap()), - [("role", "bar")].into_iter().collect(), - MemberStatus::None, - )) - .into(), - })) - .await - .unwrap(); - TokioRuntime::sleep(Duration::from_millis(10)).await; - - futures::select! { - e = rx.recv().fuse() => { - let e = e.unwrap(); - - match e { - CrateEvent::Member(MemberEvent { ty, .. 
}) => { - assert!(matches!(ty, MemberEventType::Update)); - } - _ => panic!("expected update"), - } - } - default => panic!("expected update"), - } - } - - #[test] - fn test_member_event_coalesce_pass_through() { - let cases = [ - (CrateEvent::from(UserEventMessage::default()), false), - ( - CrateEvent::from(MemberEvent { - ty: MemberEventType::Join, - members: TinyVec::new().into(), - }), - true, - ), - ( - CrateEvent::from(MemberEvent { - ty: MemberEventType::Leave, - members: TinyVec::new().into(), - }), - true, - ), - ( - CrateEvent::from(MemberEvent { - ty: MemberEventType::Failed, - members: TinyVec::new().into(), - }), - true, - ), - ( - CrateEvent::from(MemberEvent { - ty: MemberEventType::Update, - members: TinyVec::new().into(), - }), - true, - ), - ( - CrateEvent::from(MemberEvent { - ty: MemberEventType::Reap, - members: TinyVec::new().into(), - }), - true, - ), - ]; - - for (event, handle) in cases.iter() { - let coalescer = MemberEventCoalescer::::new(); - assert_eq!(coalescer.handle(event), *handle); - } - } -} +// #[cfg(all(test, feature = "test"))] +// #[allow(clippy::collapsible_match)] +// mod tests { +// use std::{net::SocketAddr, time::Duration}; + +// use futures::FutureExt; +// use memberlist_core::{ +// agnostic_lite::{RuntimeLite, tokio::TokioRuntime}, +// transport::resolver::socket_addr::SocketAddrResolver, +// }; +// use serf_proto::{MemberStatus, UserEventMessage}; +// use smol_str::SmolStr; + +// use crate::{ +// DefaultDelegate, +// coalesce::coalesced_event, +// event::{CrateEventType, MemberEvent}, +// }; + +// use super::*; + +// type Transport = UnimplementedTransport< +// SmolStr, +// SocketAddrResolver, +// Lpe, +// TokioRuntime, +// >; + +// type Delegate = DefaultDelegate; + +// #[tokio::test] +// async fn test_member_event_coealesce_basic() { +// let (tx, rx) = async_channel::unbounded(); +// let (_shutdown_tx, shutdown_rx) = async_channel::bounded(1); +// let coalescer = MemberEventCoalescer::::new(); + +// let in_ = 
coalesced_event( +// tx, +// shutdown_rx, +// Duration::from_millis(20), +// Duration::from_millis(20), +// coalescer, +// ); + +// let send = vec![ +// MemberEvent { +// ty: MemberEventType::Join, +// members: TinyVec::from(Member::new( +// Node::new("foo".into(), "127.0.0.1:8080".parse().unwrap()), +// Default::default(), +// MemberStatus::None, +// )) +// .into(), +// }, +// MemberEvent { +// ty: MemberEventType::Leave, +// members: TinyVec::from(Member::new( +// Node::new("foo".into(), "127.0.0.1:8080".parse().unwrap()), +// Default::default(), +// MemberStatus::None, +// )) +// .into(), +// }, +// MemberEvent { +// ty: MemberEventType::Leave, +// members: TinyVec::from(Member::new( +// Node::new("bar".into(), "127.0.0.1:8080".parse().unwrap()), +// Default::default(), +// MemberStatus::None, +// )) +// .into(), +// }, +// MemberEvent { +// ty: MemberEventType::Update, +// members: TinyVec::from(Member::new( +// Node::new("zip".into(), "127.0.0.1:8080".parse().unwrap()), +// [("role", "foo")].into_iter().collect(), +// MemberStatus::None, +// )) +// .into(), +// }, +// MemberEvent { +// ty: MemberEventType::Update, +// members: TinyVec::from(Member::new( +// Node::new("zip".into(), "127.0.0.1:8080".parse().unwrap()), +// [("role", "bar")].into_iter().collect(), +// MemberStatus::None, +// )) +// .into(), +// }, +// MemberEvent { +// ty: MemberEventType::Reap, +// members: TinyVec::from(Member::new( +// Node::new("dead".into(), "127.0.0.1:8080".parse().unwrap()), +// Default::default(), +// MemberStatus::None, +// )) +// .into(), +// }, +// ]; + +// for event in send { +// in_.send(CrateEvent::from(event)).await.unwrap(); +// } + +// let mut events = HashMap::new(); +// let timeout = TokioRuntime::sleep(Duration::from_millis(40)); +// futures::pin_mut!(timeout); +// loop { +// futures::select! 
{ +// e = rx.recv().fuse() => { +// let e = e.unwrap(); +// events.insert(e.ty(), e.clone()); +// } +// _ = (&mut timeout).fuse() => { +// break; +// }, +// } +// } + +// assert_eq!(events.len(), 3); + +// match events.get(&CrateEventType::Member(MemberEventType::Leave)) { +// None => panic!(""), +// Some(e) => match e { +// CrateEvent::Member(MemberEvent { members, .. }) => { +// assert_eq!(members.len(), 2); + +// let expected = ["bar", "foo"]; +// let mut names = [members[0].node.id().clone(), members[1].node.id().clone()]; +// names.sort(); + +// assert_eq!(names, expected); +// } +// _ => panic!(""), +// }, +// } + +// match events.get(&CrateEventType::Member(MemberEventType::Update)) { +// None => panic!(""), +// Some(e) => match e { +// CrateEvent::Member(MemberEvent { members, .. }) => { +// assert_eq!(members.len(), 1); +// assert_eq!(members[0].node.id(), "zip"); +// assert_eq!(members[0].tags().get("role").unwrap(), "bar"); +// } +// _ => panic!(""), +// }, +// } + +// match events.get(&CrateEventType::Member(MemberEventType::Reap)) { +// None => panic!(""), +// Some(e) => match e { +// CrateEvent::Member(MemberEvent { members, .. 
}) => { +// assert_eq!(members.len(), 1); +// assert_eq!(members[0].node.id(), "dead"); +// } +// _ => panic!(""), +// }, +// } +// } + +// #[tokio::test] +// async fn test_member_event_coalesce_tag_update() { +// let (tx, rx) = async_channel::unbounded(); +// let (_shutdown_tx, shutdown_rx) = async_channel::bounded(1); +// let coalescer = MemberEventCoalescer::::new(); + +// let in_ = coalesced_event( +// tx, +// shutdown_rx, +// Duration::from_millis(5), +// Duration::from_millis(5), +// coalescer, +// ); + +// in_ +// .send(CrateEvent::from(MemberEvent { +// ty: MemberEventType::Update, +// members: TinyVec::from(Member::new( +// Node::new("foo".into(), "127.0.0.1:8080".parse().unwrap()), +// [("role", "foo")].into_iter().collect(), +// MemberStatus::None, +// )) +// .into(), +// })) +// .await +// .unwrap(); + +// TokioRuntime::sleep(Duration::from_millis(30)).await; + +// futures::select! { +// e = rx.recv().fuse() => { +// let e = e.unwrap(); + +// match e { +// CrateEvent::Member(MemberEvent { ty, .. }) => { +// assert!(matches!(ty, MemberEventType::Update)); +// } +// _ => panic!("expected update"), +// } +// } +// default => panic!("expected update"), +// } + +// // Second update should not be suppressed even though +// // last event was an update +// in_ +// .send(CrateEvent::from(MemberEvent { +// ty: MemberEventType::Update, +// members: TinyVec::from(Member::new( +// Node::new("foo".into(), "127.0.0.1:8080".parse().unwrap()), +// [("role", "bar")].into_iter().collect(), +// MemberStatus::None, +// )) +// .into(), +// })) +// .await +// .unwrap(); +// TokioRuntime::sleep(Duration::from_millis(10)).await; + +// futures::select! { +// e = rx.recv().fuse() => { +// let e = e.unwrap(); + +// match e { +// CrateEvent::Member(MemberEvent { ty, .. 
}) => { +// assert!(matches!(ty, MemberEventType::Update)); +// } +// _ => panic!("expected update"), +// } +// } +// default => panic!("expected update"), +// } +// } + +// #[test] +// fn test_member_event_coalesce_pass_through() { +// let cases = [ +// (CrateEvent::from(UserEventMessage::default()), false), +// ( +// CrateEvent::from(MemberEvent { +// ty: MemberEventType::Join, +// members: TinyVec::new().into(), +// }), +// true, +// ), +// ( +// CrateEvent::from(MemberEvent { +// ty: MemberEventType::Leave, +// members: TinyVec::new().into(), +// }), +// true, +// ), +// ( +// CrateEvent::from(MemberEvent { +// ty: MemberEventType::Failed, +// members: TinyVec::new().into(), +// }), +// true, +// ), +// ( +// CrateEvent::from(MemberEvent { +// ty: MemberEventType::Update, +// members: TinyVec::new().into(), +// }), +// true, +// ), +// ( +// CrateEvent::from(MemberEvent { +// ty: MemberEventType::Reap, +// members: TinyVec::new().into(), +// }), +// true, +// ), +// ]; + +// for (event, handle) in cases.iter() { +// let coalescer = MemberEventCoalescer::::new(); +// assert_eq!(coalescer.handle(event), *handle); +// } +// } +// } diff --git a/serf-core/src/coalesce/user.rs b/serf-core/src/coalesce/user.rs index 0957fb0..8cf0ed6 100644 --- a/serf-core/src/coalesce/user.rs +++ b/serf-core/src/coalesce/user.rs @@ -97,141 +97,141 @@ where } } -#[cfg(all(test, feature = "test"))] -mod tests { - use std::net::SocketAddr; - - use agnostic_lite::tokio::TokioRuntime; - use memberlist_core::transport::resolver::socket_addr::SocketAddrResolver; - - use crate::{ - DefaultDelegate, - event::{MemberEvent, MemberEventType}, - }; - - use super::*; - - type Transport = UnimplementedTransport< - SmolStr, - SocketAddrResolver, - Lpe, - TokioRuntime, - >; - - type Delegate = DefaultDelegate; - - #[tokio::test] - async fn test_user_event_coalesce_basic() { - let (tx, rx) = async_channel::unbounded(); - let (_shutdown_tx, shutdown_rx) = async_channel::bounded(1); - let coalescer = 
UserEventCoalescer::::new(); - - let in_ = coalesced_event( - tx, - shutdown_rx, - Duration::from_millis(20), - Duration::from_millis(20), - coalescer, - ); - - let send = vec![ - UserEventMessage::default() - .with_name("foo".into()) - .with_cc(true) - .with_ltime(1.into()), - UserEventMessage::default() - .with_name("foo".into()) - .with_cc(true) - .with_ltime(2.into()), - UserEventMessage::default() - .with_name("bar".into()) - .with_cc(true) - .with_ltime(2.into()) - .with_payload("test1".into()), - UserEventMessage::default() - .with_name("bar".into()) - .with_cc(true) - .with_ltime(2.into()) - .with_payload("test2".into()), - ]; - - for event in send { - in_.send(CrateEvent::from(event)).await.unwrap(); - } - - let mut got_foo = false; - let mut got_bar1 = false; - let mut got_bar2 = false; - - loop { - futures::select! { - _ = TokioRuntime::sleep(Duration::from_millis(40)).fuse() => break, - event = rx.recv().fuse() => { - let event = event.unwrap(); - match event { - CrateEvent::User(e) => { - match e.name().as_str() { - "foo" => { - assert_eq!(e.ltime(), 2.into(), "bad ltime for foo"); - got_foo = true; - } - "bar" => { - assert_eq!(e.ltime(), 2.into(), "bad ltime for bar"); - if e.payload().eq("test1".as_bytes()) { - got_bar1 = true; - } - - if e.payload().eq("test2".as_bytes()) { - got_bar2 = true; - } - } - _ => unreachable!(), - } - } - _ => unreachable!(), - } - } - } - } - - assert!(got_foo && got_bar1 && got_bar2, "missing events"); - } - - #[test] - fn test_user_event_coalesce_pass_through() { - let cases = [ - (CrateEvent::from(UserEventMessage::default()), false), - ( - CrateEvent::from(UserEventMessage::default().with_cc(true)), - true, - ), - ( - CrateEvent::from(MemberEvent { - ty: MemberEventType::Join, - members: TinyVec::new().into(), - }), - false, - ), - ( - CrateEvent::from(MemberEvent { - ty: MemberEventType::Leave, - members: TinyVec::new().into(), - }), - false, - ), - ( - CrateEvent::from(MemberEvent { - ty: MemberEventType::Failed, 
- members: TinyVec::new().into(), - }), - false, - ), - ]; - - let coalescer = UserEventCoalescer::::new(); - - for (idx, (event, should_coalesce)) in cases.iter().enumerate() { - assert_eq!(coalescer.handle(event), *should_coalesce, "bad: {idx}"); - } - } -} +// #[cfg(all(test, feature = "test"))] +// mod tests { +// use std::net::SocketAddr; + +// use agnostic_lite::tokio::TokioRuntime; +// use memberlist_core::transport::resolver::socket_addr::SocketAddrResolver; + +// use crate::{ +// DefaultDelegate, +// event::{MemberEvent, MemberEventType}, +// }; + +// use super::*; + +// type Transport = UnimplementedTransport< +// SmolStr, +// SocketAddrResolver, +// Lpe, +// TokioRuntime, +// >; + +// type Delegate = DefaultDelegate; + +// #[tokio::test] +// async fn test_user_event_coalesce_basic() { +// let (tx, rx) = async_channel::unbounded(); +// let (_shutdown_tx, shutdown_rx) = async_channel::bounded(1); +// let coalescer = UserEventCoalescer::::new(); + +// let in_ = coalesced_event( +// tx, +// shutdown_rx, +// Duration::from_millis(20), +// Duration::from_millis(20), +// coalescer, +// ); + +// let send = vec![ +// UserEventMessage::default() +// .with_name("foo".into()) +// .with_cc(true) +// .with_ltime(1.into()), +// UserEventMessage::default() +// .with_name("foo".into()) +// .with_cc(true) +// .with_ltime(2.into()), +// UserEventMessage::default() +// .with_name("bar".into()) +// .with_cc(true) +// .with_ltime(2.into()) +// .with_payload("test1".into()), +// UserEventMessage::default() +// .with_name("bar".into()) +// .with_cc(true) +// .with_ltime(2.into()) +// .with_payload("test2".into()), +// ]; + +// for event in send { +// in_.send(CrateEvent::from(event)).await.unwrap(); +// } + +// let mut got_foo = false; +// let mut got_bar1 = false; +// let mut got_bar2 = false; + +// loop { +// futures::select! 
{ +// _ = TokioRuntime::sleep(Duration::from_millis(40)).fuse() => break, +// event = rx.recv().fuse() => { +// let event = event.unwrap(); +// match event { +// CrateEvent::User(e) => { +// match e.name().as_str() { +// "foo" => { +// assert_eq!(e.ltime(), 2.into(), "bad ltime for foo"); +// got_foo = true; +// } +// "bar" => { +// assert_eq!(e.ltime(), 2.into(), "bad ltime for bar"); +// if e.payload().eq("test1".as_bytes()) { +// got_bar1 = true; +// } + +// if e.payload().eq("test2".as_bytes()) { +// got_bar2 = true; +// } +// } +// _ => unreachable!(), +// } +// } +// _ => unreachable!(), +// } +// } +// } +// } + +// assert!(got_foo && got_bar1 && got_bar2, "missing events"); +// } + +// #[test] +// fn test_user_event_coalesce_pass_through() { +// let cases = [ +// (CrateEvent::from(UserEventMessage::default()), false), +// ( +// CrateEvent::from(UserEventMessage::default().with_cc(true)), +// true, +// ), +// ( +// CrateEvent::from(MemberEvent { +// ty: MemberEventType::Join, +// members: TinyVec::new().into(), +// }), +// false, +// ), +// ( +// CrateEvent::from(MemberEvent { +// ty: MemberEventType::Leave, +// members: TinyVec::new().into(), +// }), +// false, +// ), +// ( +// CrateEvent::from(MemberEvent { +// ty: MemberEventType::Failed, +// members: TinyVec::new().into(), +// }), +// false, +// ), +// ]; + +// let coalescer = UserEventCoalescer::::new(); + +// for (idx, (event, should_coalesce)) in cases.iter().enumerate() { +// assert_eq!(coalescer.handle(event), *should_coalesce, "bad: {idx}"); +// } +// } +// } diff --git a/serf-core/src/delegate/composite.rs b/serf-core/src/delegate/composite.rs index 78cdee6..40ae2e8 100644 --- a/serf-core/src/delegate/composite.rs +++ b/serf-core/src/delegate/composite.rs @@ -1,29 +1,21 @@ use memberlist_core::{ CheapClone, + proto::TinyVec, transport::{Id, Node}, - types::TinyVec, }; use serf_proto::MessageType; -use crate::{ - coordinate::Coordinate, - types::{AsMessageRef, Filter, Member, SerfMessage, Tags}, 
-}; +use crate::{coordinate::Coordinate, types::Member}; use super::{ - DefaultMergeDelegate, Delegate, LpeTransfromDelegate, MergeDelegate, NoopReconnectDelegate, - ReconnectDelegate, , + DefaultMergeDelegate, Delegate, MergeDelegate, NoopReconnectDelegate, ReconnectDelegate, }; /// `CompositeDelegate` is a helpful struct to split the [`Delegate`] into multiple small delegates, /// so that users do not need to implement full [`Delegate`] when they only want to custom some methods /// in the [`Delegate`]. -pub struct CompositeDelegate< - I, - A, - M = DefaultMergeDelegate, - R = NoopReconnectDelegate, -> { +pub struct CompositeDelegate, R = NoopReconnectDelegate> +{ merge: M, reconnect: R, _m: std::marker::PhantomData<(I, A)>, @@ -51,7 +43,7 @@ where M: MergeDelegate, { /// Set the [`MergeDelegate`] for the `CompositeDelegate`. - pub fn with_merge_delegate(self, merge: NM) -> CompositeDelegate { + pub fn with_merge_delegate(self, merge: NM) -> CompositeDelegate { CompositeDelegate { merge, reconnect: self.reconnect, @@ -62,7 +54,7 @@ where impl CompositeDelegate { /// Set the [`ReconnectDelegate`] for the `CompositeDelegate`. 
- pub fn with_reconnect_delegate(self, reconnect: NR) -> CompositeDelegate { + pub fn with_reconnect_delegate(self, reconnect: NR) -> CompositeDelegate { CompositeDelegate { reconnect, merge: self.merge, @@ -73,7 +65,7 @@ impl CompositeDelegate { impl MergeDelegate for CompositeDelegate where - I: Id, + I: Id + Send + Sync + 'static, A: CheapClone + Send + Sync + 'static, M: MergeDelegate, R: Send + Sync + 'static, @@ -94,7 +86,7 @@ where impl ReconnectDelegate for CompositeDelegate where - I: Id, + I: Id + Send + Sync + 'static, A: CheapClone + Send + Sync + 'static, M: Send + Sync + 'static, R: ReconnectDelegate, @@ -114,7 +106,7 @@ where impl Delegate for CompositeDelegate where - I: Id, + I: Id + Send + Sync + 'static, A: CheapClone + Send + Sync + 'static, M: MergeDelegate, R: ReconnectDelegate, diff --git a/serf-core/src/delegate/merge.rs b/serf-core/src/delegate/merge.rs index b2b2938..d1728fb 100644 --- a/serf-core/src/delegate/merge.rs +++ b/serf-core/src/delegate/merge.rs @@ -1,4 +1,4 @@ -use memberlist_core::{CheapClone, transport::Id, proto::TinyVec}; +use memberlist_core::{CheapClone, proto::TinyVec, transport::Id}; use std::future::Future; use crate::types::Member; @@ -39,7 +39,7 @@ impl Default for DefaultMergeDelegate { impl MergeDelegate for DefaultMergeDelegate where - I: Id, + I: Id + Send + Sync + 'static, A: CheapClone + Send + Sync + 'static, { type Error = std::convert::Infallible; diff --git a/serf-core/src/delegate/reconnect.rs b/serf-core/src/delegate/reconnect.rs index 4073052..98b3476 100644 --- a/serf-core/src/delegate/reconnect.rs +++ b/serf-core/src/delegate/reconnect.rs @@ -40,7 +40,7 @@ impl Copy for NoopReconnectDelegate {} impl ReconnectDelegate for NoopReconnectDelegate where - I: Id, + I: Id + Send + Sync + 'static, A: CheapClone + Send + Sync + 'static, { type Id = I; diff --git a/serf-core/src/error.rs b/serf-core/src/error.rs index 8e5d84d..5cc683e 100644 --- a/serf-core/src/error.rs +++ b/serf-core/src/error.rs @@ -1,9 +1,6 
@@ -use std::collections::HashMap; +use std::sync::Arc; -use memberlist_core::{ - transport::{AddressResolver, MaybeResolvedAddress, Node, Transport}, - proto::{SmallVec, TinyVec}, -}; +use memberlist_core::{proto::TinyVec, transport::Transport}; use crate::{ delegate::Delegate, @@ -14,10 +11,10 @@ use crate::{ pub use crate::snapshot::SnapshotError; /// Error type for the serf crate. -#[derive(thiserror::Error)] +#[derive(Debug, thiserror::Error)] pub enum Error where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { /// Returned when the underlyhing memberlist error @@ -29,35 +26,34 @@ where /// Returned when the relay error #[error(transparent)] Relay(#[from] RelayError), + /// Multiple errors + #[error("errors:\n{}", format_multiple_errors(.0))] + Multiple(Arc<[Self]>), } -impl core::fmt::Debug for Error +impl From for Error where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Memberlist(e) => write!(f, "{e:?}"), - Self::Serf(e) => write!(f, "{e:?}"), - Self::Relay(e) => write!(f, "{e:?}"), - } + fn from(value: SnapshotError) -> Self { + Self::Serf(SerfError::Snapshot(value)) } } -impl From for Error +impl From for Error where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { - fn from(value: SnapshotError) -> Self { - Self::Serf(SerfError::Snapshot(value)) + fn from(e: memberlist_core::proto::EncodeError) -> Self { + Self::Memberlist(e.into()) } } impl Error where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { /// Create a query response too large error @@ -234,28 +230,28 @@ pub enum SerfError { pub struct RelayError( #[allow(clippy::type_complexity)] TinyVec<( - Member::ResolvedAddress>, + Member, memberlist_core::error::Error>, )>, ) where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport; impl From< TinyVec<( - Member::ResolvedAddress>, + Member, memberlist_core::error::Error>, )>, > for 
RelayError where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { fn from( value: TinyVec<( - Member::ResolvedAddress>, + Member, memberlist_core::error::Error>, )>, ) -> Self { @@ -265,7 +261,7 @@ where impl core::fmt::Display for RelayError where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { @@ -285,7 +281,7 @@ where impl core::fmt::Debug for RelayError where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { @@ -295,91 +291,20 @@ where impl std::error::Error for RelayError where - D: Delegate::ResolvedAddress>, - T: Transport, -{ -} - -/// `JoinError` is returned when join is partially/totally failed. -pub struct JoinError -where - D: Delegate::ResolvedAddress>, - T: Transport, -{ - pub(crate) joined: SmallVec::ResolvedAddress>>, - pub(crate) errors: HashMap>, Error>, - pub(crate) broadcast_error: Option>, -} - -impl JoinError -where - D: Delegate::ResolvedAddress>, - T: Transport, -{ - /// Returns the broadcast error that occurred during the join. - #[inline] - pub const fn broadcast_error(&self) -> Option<&Error> { - self.broadcast_error.as_ref() - } - - /// Returns the errors that occurred during the join. - #[inline] - pub const fn errors(&self) -> &HashMap>, Error> { - &self.errors - } - - /// Returns the nodes have successfully joined. - #[inline] - pub const fn joined( - &self, - ) -> &SmallVec::ResolvedAddress>> { - &self.joined - } - - /// Returns how many nodes have successfully joined. 
- #[inline] - pub fn num_joined(&self) -> usize { - self.joined.len() - } -} - -impl core::fmt::Debug for JoinError -where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "{}", self) - } -} - -impl core::fmt::Display for JoinError -where - D: Delegate::ResolvedAddress>, - T: Transport, -{ - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - if !self.joined.is_empty() { - writeln!(f, "Successes: {:?}", self.joined)?; - } - - if !self.errors.is_empty() { - writeln!(f, "Failures:")?; - for (address, err) in self.errors.iter() { - writeln!(f, "\t{}: {}", address, err)?; - } - } - - if let Some(err) = &self.broadcast_error { - writeln!(f, "Broadcast Error: {err}")?; - } - Ok(()) - } } -impl std::error::Error for JoinError +fn format_multiple_errors(errors: &[Error]) -> String where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { + errors + .iter() + .enumerate() + .map(|(i, err)| format!(" {}. 
{}", i + 1, err)) + .collect::>() + .join("\n") } diff --git a/serf-core/src/event.rs b/serf-core/src/event.rs index faeb747..0458f8b 100644 --- a/serf-core/src/event.rs +++ b/serf-core/src/event.rs @@ -14,18 +14,16 @@ pub(crate) use crate_event::*; use futures::Stream; use memberlist_core::{ CheapClone, - bytes::{BufMut, Bytes, BytesMut}, - transport::{AddressResolver, Transport}, + bytes::Bytes, proto::TinyVec, + transport::{AddressResolver, Transport}, }; -use serf_proto::{ - LamportTime, Member, MessageType, Node, QueryFlag, QueryResponseMessage, UserEventMessage, -}; +use serf_proto::{LamportTime, Member, Node, QueryFlag, QueryResponseMessage, UserEventMessage}; use smol_str::SmolStr; pub(crate) struct QueryContext where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { pub(crate) query_timeout: Duration, @@ -35,15 +33,14 @@ where impl QueryContext where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { - fn check_response_size(&self, resp: &[u8]) -> Result<(), Error> { - let resp_len = resp.len(); - if resp_len > self.this.inner.opts.query_response_size_limit { + fn check_response_size(&self, size: usize) -> Result<(), Error> { + if size > self.this.inner.opts.query_response_size_limit { Err(Error::query_response_too_large( self.this.inner.opts.query_response_size_limit, - resp_len, + size, )) } else { Ok(()) @@ -52,12 +49,12 @@ where async fn respond_with_message_and_response( &self, - respond_to: &::ResolvedAddress, + respond_to: &T::ResolvedAddress, relay_factor: u8, raw: Bytes, - resp: QueryResponseMessage::ResolvedAddress>, + resp: QueryResponseMessage, ) -> Result<(), Error> { - self.check_response_size(raw.as_ref())?; + self.check_response_size(raw.len())?; let mut mu = self.span.lock().await; @@ -86,7 +83,7 @@ where async fn respond( &self, - respond_to: &::ResolvedAddress, + respond_to: &T::ResolvedAddress, id: u32, ltime: LamportTime, relay_factor: u8, @@ -99,18 +96,9 @@ where flags: QueryFlag::empty(), payload: msg, }; - 
let expected_encoded_len = ::message_encoded_len(&resp); - let mut buf = BytesMut::with_capacity(expected_encoded_len + 1); // +1 for the message type byte - buf.put_u8(MessageType::QueryResponse as u8); - buf.resize(expected_encoded_len + 1, 0); - let len = ::encode_message(&resp, &mut buf[1..]) - .map_err(Error::transform_delegate)?; - debug_assert_eq!( - len, expected_encoded_len, - "expected encoded len {expected_encoded_len} is not match the actual encoded len {len}" - ); + let buf = serf_proto::Encodable::encode_to_bytes(&resp)?; self - .respond_with_message_and_response(respond_to, relay_factor, buf.freeze(), resp) + .respond_with_message_and_response(respond_to, relay_factor, buf, resp) .await } } @@ -118,7 +106,7 @@ where /// Query event pub struct QueryEvent where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { pub(crate) ltime: LamportTime, @@ -128,14 +116,14 @@ where pub(crate) ctx: Arc>, pub(crate) id: u32, /// source node - pub(crate) from: Node::ResolvedAddress>, + pub(crate) from: Node, /// Number of duplicate responses to relay back to sender pub(crate) relay_factor: u8, } impl QueryEvent where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { /// Returns the lamport time of the query @@ -164,14 +152,14 @@ where /// Returns the source node of the query #[inline] - pub const fn from(&self) -> &Node::ResolvedAddress> { + pub const fn from(&self) -> &Node { &self.from } } impl PartialEq for QueryEvent where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { fn eq(&self, other: &Self) -> bool { @@ -186,7 +174,7 @@ where impl AsRef> for QueryEvent where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { fn as_ref(&self) -> &QueryEvent { @@ -196,7 +184,7 @@ where impl Clone for QueryEvent where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { fn clone(&self) -> Self { @@ -214,7 +202,7 @@ where impl core::fmt::Display for QueryEvent where - D: Delegate::ResolvedAddress>, + D: Delegate, 
T: Transport, { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { @@ -224,14 +212,14 @@ where impl QueryEvent where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { #[cfg(feature = "encryption")] pub(crate) fn create_response( &self, buf: Bytes, - ) -> QueryResponseMessage::ResolvedAddress> { + ) -> QueryResponseMessage { QueryResponseMessage { ltime: self.ltime, id: self.id, @@ -242,15 +230,15 @@ where } #[cfg(feature = "encryption")] - pub(crate) fn check_response_size(&self, resp: &[u8]) -> Result<(), Error> { - self.ctx.check_response_size(resp) + pub(crate) fn check_response_size(&self, size: usize) -> Result<(), Error> { + self.ctx.check_response_size(size) } #[cfg(feature = "encryption")] pub(crate) async fn respond_with_message_and_response( &self, raw: Bytes, - resp: QueryResponseMessage::ResolvedAddress>, + resp: QueryResponseMessage, ) -> Result<(), Error> { self .ctx @@ -383,11 +371,11 @@ impl From> for (MemberEventType, Arc where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { /// Member related events - Member(MemberEvent::ResolvedAddress>), + Member(MemberEvent), /// User events User(UserEventMessage), /// Query events @@ -396,7 +384,7 @@ where impl Clone for Event where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { fn clone(&self) -> Self { @@ -412,7 +400,7 @@ where #[derive(Debug)] pub struct EventProducer where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { pub(crate) tx: Sender>, @@ -420,7 +408,7 @@ where impl EventProducer where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { /// Creates a bounded producer and subscriber. @@ -446,7 +434,7 @@ where #[derive(Debug)] pub struct EventSubscriber where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { #[pin] @@ -455,7 +443,7 @@ where impl EventSubscriber where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { /// Receives a event from the subscriber. 
@@ -509,7 +497,7 @@ where impl Stream for EventSubscriber where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { type Item = Event; diff --git a/serf-core/src/event/crate_event.rs b/serf-core/src/event/crate_event.rs index df71d7b..1423abf 100644 --- a/serf-core/src/event/crate_event.rs +++ b/serf-core/src/event/crate_event.rs @@ -1,21 +1,23 @@ +use memberlist_core::proto::{Data, DecodeError}; use serf_proto::QueryMessage; use super::*; -pub(crate) trait QueryMessageExt { - fn decode_internal_query( - &self, - ) -> Option, T::Error>>; +pub(crate) trait QueryMessageExt { + fn decode_internal_query(&self) -> Option, DecodeError>>; } -impl QueryMessageExt for QueryMessage { - fn decode_internal_query( - &self, - ) -> Option, T::Error>> { +impl QueryMessageExt for QueryMessage +where + I: Data, +{ + fn decode_internal_query(&self) -> Option, DecodeError>> { Some(Ok(match self.name().as_str() { INTERNAL_PING => InternalQueryEvent::Ping, INTERNAL_CONFLICT => { - return Some(T::decode_id(&self.payload).map(|(_, id)| InternalQueryEvent::Conflict(id))); + return Some( + ::decode(&self.payload).map(|(_, id)| InternalQueryEvent::Conflict(id)), + ); } #[cfg(feature = "encryption")] INTERNAL_INSTALL_KEY => InternalQueryEvent::InstallKey, diff --git a/serf-core/src/key_manager.rs b/serf-core/src/key_manager.rs index 10ab920..a67f139 100644 --- a/serf-core/src/key_manager.rs +++ b/serf-core/src/key_manager.rs @@ -6,9 +6,9 @@ use futures::StreamExt; use memberlist_core::{ CheapClone, bytes::{BufMut, BytesMut}, + proto::SecretKey, tracing, transport::{AddressResolver, Transport}, - types::SecretKey, }; use smol_str::SmolStr; @@ -19,10 +19,10 @@ use crate::event::{ use super::{ Serf, - delegate::{Delegate, }, + delegate::Delegate, error::Error, serf::{NodeResponse, QueryResponse}, - types::{KeyRequestMessage, MessageType, SerfMessage}, + types::{KeyRequestMessage, MessageType}, }; /// KeyResponse is used to relay a query for a list of all keys in use. 
@@ -186,19 +186,7 @@ where event: InternalQueryEvent, ) -> Result, Error> { let kr = KeyRequestMessage { key }; - let expected_encoded_len = ::message_encoded_len(&kr); - let mut buf = BytesMut::with_capacity(expected_encoded_len + 1); // +1 for the message type - buf.put_u8(MessageType::KeyRequest as u8); - buf.resize(expected_encoded_len + 1, 0); - // Encode the query request - let len = ::encode_message(&kr, &mut buf[1..]) - .map_err(Error::transform_delegate)?; - - debug_assert_eq!( - len, expected_encoded_len, - "expected encoded len {} mismatch the actual encoded len {}", - expected_encoded_len, len - ); + let buf = serf_proto::Encodable::encode_to_bytes(&kr)?; let serf = self.serf.get().unwrap(); let mut q_param = serf.default_query_param().await; @@ -206,7 +194,7 @@ where q_param.relay_factor = opts.relay_factor; } let qresp: QueryResponse::ResolvedAddress> = serf - .internal_query(SmolStr::new(ty), buf.freeze(), Some(q_param), event) + .internal_query(SmolStr::new(ty), buf, Some(q_param), event) .await?; // Handle the response stream and populate the KeyResponse @@ -249,7 +237,7 @@ where resp.num_resp += 1; // Decode the response - if r.payload.is_empty() || r.payload[0] != MessageType::KeyResponse as u8 { + if r.payload.is_empty() || r.payload[0] != u8::from(MessageType::KeyResponse) { resp.messages.insert( r.from.id().cheap_clone(), SmolStr::new(format!( @@ -265,30 +253,16 @@ where continue; } - let node_response = - match ::decode_message(MessageType::KeyResponse, &r.payload[1..]) { - Ok((_, nr)) => match nr { - SerfMessage::KeyResponse(kr) => kr, - msg => { - resp.messages.insert( - r.from.id().cheap_clone(), - SmolStr::new(format!( - "Invalid key query response type: {:?}", - msg.ty().as_str() - )), - ); - resp.num_err += 1; - - if resp.num_resp == resp.num_nodes { - return resp; - } - continue; - } - }, - Err(e) => { + let node_response = match decode_message(MessageType::KeyResponse, &r.payload[1..]) { + Ok((_, nr)) => match nr { + 
SerfMessage::KeyResponse(kr) => kr, + msg => { resp.messages.insert( r.from.id().cheap_clone(), - SmolStr::new(format!("Failed to decode key query response: {:?}", e)), + SmolStr::new(format!( + "Invalid key query response type: {:?}", + msg.ty().as_str() + )), ); resp.num_err += 1; @@ -297,7 +271,20 @@ where } continue; } - }; + }, + Err(e) => { + resp.messages.insert( + r.from.id().cheap_clone(), + SmolStr::new(format!("Failed to decode key query response: {:?}", e)), + ); + resp.num_err += 1; + + if resp.num_resp == resp.num_nodes { + return resp; + } + continue; + } + }; if !node_response.result { resp diff --git a/serf-core/src/lib.rs b/serf-core/src/lib.rs index cc9cba2..5255732 100644 --- a/serf-core/src/lib.rs +++ b/serf-core/src/lib.rs @@ -1,7 +1,7 @@ #![doc = include_str!("../../README.md")] #![doc(html_logo_url = "https://raw.githubusercontent.com/al8n/serf/main/art/logo_72x72.png")] #![forbid(unsafe_code)] -#![deny(warnings, missing_docs)] +// #![deny(warnings, missing_docs)] #![allow(clippy::type_complexity)] #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, allow(unused_attributes))] @@ -52,7 +52,7 @@ pub mod tests { pub use memberlist_core::tests::{AnyError, next_socket_addr_v4, next_socket_addr_v6}; pub use paste; - pub use super::serf::base::tests::{serf::*, *}; + // pub use super::serf::base::tests::{serf::*, *}; /// Add `test` prefix to the predefined unit test fn with a given [`Runtime`](memberlist_core::agnostic_lite::RuntimeLite) #[cfg(any(feature = "test", test))] diff --git a/serf-core/src/serf.rs b/serf-core/src/serf.rs index 7a6b53e..d63e95d 100644 --- a/serf-core/src/serf.rs +++ b/serf-core/src/serf.rs @@ -9,9 +9,9 @@ use futures::stream::FuturesUnordered; use memberlist_core::{ Memberlist, agnostic_lite::{AsyncSpawner, RuntimeLite}, + proto::MediumVec, queue::TransmitLimitedQueue, transport::{AddressResolver, Transport}, - proto::MediumVec, }; use super::{ diff --git a/serf-core/src/serf/api.rs b/serf-core/src/serf/api.rs index 
d40dd94..efcdf19 100644 --- a/serf-core/src/serf/api.rs +++ b/serf-core/src/serf/api.rs @@ -2,15 +2,18 @@ use std::sync::atomic::Ordering; use futures::{FutureExt, StreamExt}; use memberlist_core::{ - bytes::{BufMut, Bytes, BytesMut}, proto::Data, tracing, transport::{MaybeResolvedAddress, Node}, types::{Meta, OneOrMore, SmallVec}, CheapClone + CheapClone, + bytes::Bytes, + proto::{Data, Meta, OneOrMore, SmallVec}, + tracing, + transport::{MaybeResolvedAddress, Node}, }; use smol_str::SmolStr; use crate::{ - delegate::, - error::{Error, JoinError}, + error::Error, event::EventProducer, - types::{LeaveMessage, Member, MessageType, SerfMessage, Tags, UserEventMessage}, + types::{LeaveMessage, Member, Tags, UserEventMessage}, }; use super::*; @@ -270,7 +273,7 @@ where }; // Start broadcasting the event - let len = ::message_encoded_len(&msg); + let len = serf_proto::Encodable::encoded_len(&msg); // Check the size after encoding to be sure again that // we're not attempting to send over the specified size limit. 
@@ -282,17 +285,7 @@ where return Err(Error::raw_user_event_too_large(len)); } - let mut raw = BytesMut::with_capacity(len + 1); // + 1 for message type byte - raw.put_u8(MessageType::UserEvent as u8); - raw.resize(len + 1, 0); - - let actual_encoded_len = ::encode_message(&msg, &mut raw[1..]) - .map_err(Error::transform_delegate)?; - debug_assert_eq!( - actual_encoded_len, len, - "expected encoded len {} mismatch the actual encoded len {}", - len, actual_encoded_len - ); + let raw = serf_proto::Encodable::encode_to_bytes(&msg)?; self.inner.event_clock.increment(); @@ -303,7 +296,7 @@ where .inner .event_broadcasts .queue_broadcast(SerfBroadcast { - msg: raw.freeze(), + msg: raw, notify_tx: None, }) .await; @@ -382,20 +375,13 @@ where existing: impl Iterator>>, ignore_old: bool, ) -> Result< - SmallVec::ResolvedAddress>>, - JoinError, + SmallVec>, + (SmallVec>, Error), > { // Do a quick state check let current_state = self.state(); if current_state != SerfState::Alive { - return Err(JoinError { - joined: SmallVec::new(), - errors: existing - .into_iter() - .map(|node| (node, Error::bad_join_status(current_state))) - .collect(), - broadcast_error: None, - }); + return Err((SmallVec::new(), Error::bad_join_status(current_state))); } // Hold the joinLock, this is to make eventJoinIgnore safe @@ -413,51 +399,28 @@ where // Start broadcasting the update if let Err(e) = self.broadcast_join(self.inner.clock.time()).await { self.inner.event_join_ignore.store(false, Ordering::SeqCst); - return Err(JoinError { - joined, - errors: Default::default(), - broadcast_error: Some(e), - }); + return Err((joined, e)); } self.inner.event_join_ignore.store(false, Ordering::SeqCst); Ok(joined) } - Err(e) => { - let (joined, errors) = e.into(); + Err((joined, err)) => { // If we joined any nodes, broadcast the join message if !joined.is_empty() { // Start broadcasting the update if let Err(e) = self.broadcast_join(self.inner.clock.time()).await { self.inner.event_join_ignore.store(false, 
Ordering::SeqCst); - return Err(JoinError { + return Err(( joined, - errors: errors - .into_iter() - .map(|(addr, err)| (addr, err.into())) - .collect(), - broadcast_error: Some(e), - }); + Error::Multiple(std::sync::Arc::from_iter([err.into(), e])), + )); } self.inner.event_join_ignore.store(false, Ordering::SeqCst); - Err(JoinError { - joined, - errors: errors - .into_iter() - .map(|(addr, err)| (addr, err.into())) - .collect(), - broadcast_error: None, - }) + Err((joined, Error::from(err))) } else { self.inner.event_join_ignore.store(false, Ordering::SeqCst); - Err(JoinError { - joined, - errors: errors - .into_iter() - .map(|(addr, err)| (addr, err.into())) - .collect(), - broadcast_error: None, - }) + Err((joined, Error::from(err))) } } } @@ -498,12 +461,11 @@ where // Process the leave locally self.handle_node_leave_intent(&msg).await; - let msg = SerfMessage::Leave(msg); - // Only broadcast the leave message if there is at least one // other node alive. if self.has_alive_members().await { let (notify_tx, notify_rx) = async_channel::bounded(1); + let msg = serf_proto::Encodable::encode_to_bytes(&msg)?; self.broadcast(msg, Some(notify_tx)).await?; futures::select! 
{ @@ -632,7 +594,6 @@ where } #[viewit::viewit(vis_all = "", getters(vis_all = "pub", prefix = "get"), setters(skip))] -#[cfg_attr(feature = "async-graphql", derive(async_graphql::SimpleObject))] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Stats { members: usize, diff --git a/serf-core/src/serf/base.rs b/serf-core/src/serf/base.rs index 248d57c..c17aac4 100644 --- a/serf-core/src/serf/base.rs +++ b/serf-core/src/serf/base.rs @@ -3,12 +3,11 @@ use std::time::Duration; use futures::{FutureExt, StreamExt}; use memberlist_core::{ CheapClone, - agnostic_lite::Detach, bytes::{BufMut, Bytes, BytesMut}, delegate::EventDelegate, + proto::{Meta, NodeState, OneOrMore, TinyVec}, tracing, transport::{MaybeResolvedAddress, Node}, - types::{Meta, NodeState, OneOrMore, TinyVec}, }; use rand::{Rng, SeedableRng}; use smol_str::SmolStr; @@ -17,14 +16,13 @@ use crate::{ QueueOptions, coalesce::{MemberEventCoalescer, UserEventCoalescer, coalesced_event}, coordinate::CoordinateOptions, - delegate::, error::Error, event::{InternalQueryEvent, MemberEvent, MemberEventType, QueryContext, QueryEvent}, snapshot::{Snapshot, open_and_replay_snapshot}, types::{ DelegateVersion, Epoch, JoinMessage, LeaveMessage, Member, MemberState, MemberStatus, MemberlistDelegateVersion, MemberlistProtocolVersion, MessageType, NodeIntent, ProtocolVersion, - QueryFlag, QueryMessage, QueryResponseMessage, SerfMessage, UserEvent, UserEventMessage, + QueryFlag, QueryMessage, QueryResponseMessage, UserEvent, UserEventMessage, }, }; @@ -32,10 +30,10 @@ use self::internal_query::SerfQueries; use super::*; -/// Re-export the unit tests -#[cfg(feature = "test")] -#[cfg_attr(docsrs, doc(cfg(feature = "test")))] -pub mod tests; +// /// Re-export the unit tests +// #[cfg(feature = "test")] +// #[cfg_attr(docsrs, doc(cfg(feature = "test")))] +// pub mod tests; impl Serf where @@ -74,7 +72,7 @@ where { let tags = opts.tags.load(); if !tags.as_ref().is_empty() { - let len = 
::tags_encoded_len(&tags); + let len = tags_encoded_len(&tags); if len > Meta::MAX_SIZE { return Err(Error::tags_too_large(len)); } @@ -128,7 +126,7 @@ where // Try access the snapshot let (old_clock, old_event_clock, old_query_clock, event_tx, alive_nodes, handle) = if let Some(sp) = opts.snapshot_path.as_ref() { - let rs = open_and_replay_snapshot::<_, _, D, _>(sp, opts.rejoin_after_leave)?; + let rs = open_and_replay_snapshot::<_, _, D>(sp, opts.rejoin_after_leave)?; let old_clock = rs.last_clock; let old_event_clock = rs.last_event_clock; let old_query_clock = rs.last_query_clock; @@ -362,29 +360,13 @@ where /// when the broadcast is sent. pub(crate) async fn broadcast( &self, - msg: SerfMessage::ResolvedAddress>, + msg: Bytes, notify_tx: Option>, ) -> Result<(), Error> { - let ty = MessageType::from(&msg); - let expected_encoded_len = ::message_encoded_len(&msg); - let mut raw = BytesMut::with_capacity(expected_encoded_len + 1); // + 1 for message type byte - raw.put_u8(ty as u8); - raw.resize(expected_encoded_len + 1, 0); - let len = ::encode_message(&msg, &mut raw[1..]) - .map_err(Error::transform_delegate)?; - debug_assert_eq!( - len, expected_encoded_len, - "expected encoded len {} mismatch the actual encoded len {}", - expected_encoded_len, len - ); - self .inner .broadcasts - .queue_broadcast(SerfBroadcast { - msg: raw.into(), - notify_tx, - }) + .queue_broadcast(SerfBroadcast { msg, notify_tx }) .await; Ok(()) } @@ -423,7 +405,6 @@ where if let Some(keyring) = self.inner.memberlist.keyring() { let encoded_keys = keyring .keys() - .await .map(|k| general_purpose::STANDARD.encode(k)) .collect::>(); @@ -669,7 +650,7 @@ where let num_alive = (mu.states.len() - num_failed - mu.left_members.len()).max(1); let prob = num_failed as f32 / num_alive as f32; - let r: f32 = rng.gen(); + let r: f32 = rng.random(); if r > prob { tracing::debug!("serf: forgoing reconnect for random throttling"); continue; @@ -888,9 +869,7 @@ where let local = 
self.inner.memberlist.advertise_node(); // Encode the filters - let filters = params - .encode_filters::() - .map_err(Error::transform_delegate)?; + let filters = params.encode_filters::()?; // Setup the flags let flags = if params.request_ack { @@ -913,23 +892,14 @@ where }; // Encode the query - let len = ::message_encoded_len(&q); + let len = serf_proto::Encodable::encoded_len(&q); // Check the size if len > self.inner.opts.query_size_limit { return Err(Error::query_too_large(len)); } - let mut raw = BytesMut::with_capacity(len + 1); // + 1 for message type byte - raw.put_u8(MessageType::Query as u8); - raw.resize(len + 1, 0); - let actual_encoded_len = ::encode_message(&q, &mut raw[1..]) - .map_err(Error::transform_delegate)?; - debug_assert_eq!( - actual_encoded_len, len, - "expected encoded len {} mismatch the actual encoded len {}", - len, actual_encoded_len - ); + let raw = serf_proto::Encodable::encode_to_bytes(&q)?; // Register QueryResponse to track acks and responses let resp = QueryResponse::from_query(&q, self.inner.memberlist.num_online_members().await); @@ -945,7 +915,7 @@ where .inner .query_broadcasts .queue_broadcast(SerfBroadcast { - msg: raw.freeze(), + msg: raw, notify_tx: None, }) .await; @@ -1067,24 +1037,9 @@ where payload: Bytes::new(), }; - let expected_encoded_len = ::message_encoded_len(&ack); - let mut raw = BytesMut::with_capacity(expected_encoded_len + 1); // + 1 for message type byte - raw.put_u8(MessageType::QueryResponse as u8); - raw.resize(expected_encoded_len + 1, 0); - - match ::encode_message(&ack, &mut raw[1..]) { - Ok(len) => { - debug_assert_eq!( - len, expected_encoded_len, - "expected encoded len {} mismatch the actual encoded len {}", - expected_encoded_len, len - ); - if let Err(e) = self - .inner - .memberlist - .send(q.from().address(), raw.freeze()) - .await - { + match serf_proto::Encodable::encode_to_bytes(&ack) { + Ok(raw) => { + if let Err(e) = self.inner.memberlist.send(q.from().address(), raw).await { 
tracing::error!(err=%e, "serf: failed to send ack"); } @@ -1181,7 +1136,7 @@ where let node = n.node(); let tags = if !n.meta().is_empty() { - match ::decode_tags(n.meta()) { + match decode_tags(n.meta()) { Ok((readed, tags)) => { tracing::trace!(read = %readed, tags=?tags, "serf: decode tags successfully"); tags @@ -1544,7 +1499,7 @@ where &self, n: Arc::ResolvedAddress>>, ) { - let tags = match ::decode_tags(n.meta()) { + let tags = match decode_tags(n.meta()) { Ok((readed, tags)) => { tracing::trace!(read = %readed, tags=?tags, "serf: decode tags successfully"); tags @@ -1662,10 +1617,10 @@ where // Get the local node let local_id = self.inner.memberlist.local_id(); let local_advertise_addr = self.inner.memberlist.advertise_address(); - let encoded_id_len = ::id_encoded_len(local_id); + let encoded_id_len = id_encoded_len(local_id); let mut payload = vec![0u8; encoded_id_len]; - if let Err(e) = ::encode_id(local_id, &mut payload) { + if let Err(e) = encode_id(local_id, &mut payload) { tracing::error!(err=%e, "serf: failed to encode local id"); return; } @@ -1699,8 +1654,7 @@ where continue; } - match ::decode_message(MessageType::ConflictResponse, &r.payload[1..]) - { + match decode_message(MessageType::ConflictResponse, &r.payload[1..]) { Ok((_, decoded)) => { match decoded { SerfMessage::ConflictResponse(member) => { diff --git a/serf-core/src/serf/base/tests.rs b/serf-core/src/serf/base/tests.rs index f7c4564..8073112 100644 --- a/serf-core/src/serf/base/tests.rs +++ b/serf-core/src/serf/base/tests.rs @@ -364,16 +364,11 @@ pub async fn estimate_max_keys_in_list_key_response_factor( let mut found = 0; for i in (0..=resp.keys.len()).rev() { - let encoded_len = as >::message_encoded_len(&resp); - let mut dst = vec![0; encoded_len]; - as >::encode_message(&resp, &mut dst).unwrap(); + let dst = serf_proto::Encodable::encode_to_bytes(&resp).unwrap(); - let qresp = query.create_response(dst.into()); - let encoded_len = as >::message_encoded_len(&qresp); - let mut 
dst = vec![0; encoded_len]; - as >::encode_message(&qresp, &mut dst).unwrap(); - - if query.check_response_size(&dst).is_err() { + let qresp = query.create_response(dst); + let dst = serf_proto::Encodable::encode_to_bytes(&qresp).unwrap(); + if query.check_response_size(dst.len()).is_err() { resp.keys.truncate(i); continue; } diff --git a/serf-core/src/serf/base/tests/serf/delegate.rs b/serf-core/src/serf/base/tests/serf/delegate.rs index c81f10a..aa50fee 100644 --- a/serf-core/src/serf/base/tests/serf/delegate.rs +++ b/serf-core/src/serf/base/tests/serf/delegate.rs @@ -14,7 +14,7 @@ where .unwrap(); let meta = s.inner.memberlist.delegate().unwrap().node_meta(32).await; - let (_, tags) = as >::decode_tags(&meta).unwrap(); + let (_, tags) = decode_tags(&meta).unwrap(); assert_eq!(tags.get("role"), Some(&SmolStr::new("test"))); s.shutdown().await.unwrap(); @@ -82,7 +82,7 @@ where // Attempt a decode let (_, pp) = - as >::decode_message(MessageType::PushPull, &buf[1..]) + decode_message(MessageType::PushPull, &buf[1..]) .unwrap(); let SerfMessage::PushPull(pp) = pp else { @@ -137,22 +137,20 @@ where .collect(), left_members: ["foo".into()].into_iter().collect(), event_ltime: 50.into(), - events: TinyVec::from(Some(UserEvents { + events: TinyVec::from(UserEvents { ltime: 45.into(), events: OneOrMore::from(UserEvent { name: "test".into(), payload: Bytes::new(), }), - })), + }), query_ltime: 100.into(), }; - - let mut buf = vec![0; as >::message_encoded_len(&pp) + 1]; - buf[0] = MessageType::PushPull as u8; - as >::encode_message(&pp, &mut buf[1..]).unwrap(); + + let buf = serf_proto::Encodable::encode_to_bytes(&pp).unwrap(); // Merge in fake state - d.merge_remote_state(buf.into(), false).await; + d.merge_remote_state(buf, false).await; // Verify lamport assert_eq!(s.inner.clock.time(), 42.into(), "bad lamport clock"); diff --git a/serf-core/src/serf/base/tests/serf/event.rs b/serf-core/src/serf/base/tests/serf/event.rs index 0a36144..c4dae30 100644 --- 
a/serf-core/src/serf/base/tests/serf/event.rs +++ b/serf-core/src/serf/base/tests/serf/event.rs @@ -571,15 +571,15 @@ where assert_eq!(filters.len(), 3); let (_, node_filt) = - as >::decode_filter(&filters[0]).unwrap(); + decode_filter(&filters[0]).unwrap(); assert_eq!(node_filt.ty(), FilterType::Id); let (_, tag_filt) = - as >::decode_filter(&filters[1]).unwrap(); + decode_filter(&filters[1]).unwrap(); assert_eq!(tag_filt.ty(), FilterType::Tag); let (_, tag_filt) = - as >::decode_filter(&filters[2]).unwrap(); + decode_filter(&filters[2]).unwrap(); assert_eq!(tag_filt.ty(), FilterType::Tag); } diff --git a/serf-core/src/serf/delegate.rs b/serf-core/src/serf/delegate.rs index 0ca379f..bfefe0c 100644 --- a/serf-core/src/serf/delegate.rs +++ b/serf-core/src/serf/delegate.rs @@ -1,13 +1,13 @@ use crate::{ Serf, broadcast::SerfBroadcast, - delegate::{Delegate, }, - error::{SerfDelegateError, SerfError}, + delegate::Delegate, + error::SerfError, event::QueryMessageExt, types::{ DelegateVersion, JoinMessage, LamportTime, LeaveMessage, Member, MemberStatus, MemberlistDelegateVersion, MemberlistProtocolVersion, MessageType, ProtocolVersion, - PushPullMessageRef, SerfMessage, UserEventMessage, + PushPullMessageBorrow, SerfMessage, UserEventMessage, }, }; @@ -22,9 +22,9 @@ use memberlist_core::{ AliveDelegate, ConflictDelegate, Delegate as MemberlistDelegate, EventDelegate, MergeDelegate as MemberlistMergeDelegate, NodeDelegate, PingDelegate, }, + proto::{Meta, NodeState, SmallVec, State, TinyVec}, tracing, transport::{AddressResolver, Transport}, - types::{Meta, NodeState, SmallVec, State, TinyVec}, }; use serf_proto::Tags; @@ -117,7 +117,7 @@ where let tags = self.tags.load(); match tags.is_empty() { false => { - let encoded_len = ::tags_encoded_len(&tags); + let encoded_len = tags_encoded_len(&tags); let limit = limit.min(Meta::MAX_SIZE); if encoded_len > limit { panic!( @@ -127,7 +127,7 @@ where } let mut role_bytes = vec![0; encoded_len]; - match ::encode_tags(&tags, 
&mut role_bytes) { + match encode_tags(&tags, &mut role_bytes) { Ok(len) => { debug_assert_eq!( len, encoded_len, @@ -190,7 +190,7 @@ where } match ty { - MessageType::Leave => match ::decode_message(ty, &msg[1..]) { + MessageType::Leave => match decode_message(ty, &msg[1..]) { Ok((_, l)) => { if let SerfMessage::Leave(l) = &l { tracing::debug!("serf: leave message: {}", l.id()); @@ -203,7 +203,7 @@ where tracing::warn!(err=%e, "serf: failed to decode message"); } }, - MessageType::Join => match ::decode_message(ty, &msg[1..]) { + MessageType::Join => match decode_message(ty, &msg[1..]) { Ok((_, j)) => { if let SerfMessage::Join(j) = &j { tracing::debug!("serf: join message: {}", j.id()); @@ -216,7 +216,7 @@ where tracing::warn!(err=%e, "serf: failed to decode message"); } }, - MessageType::UserEvent => match ::decode_message(ty, &msg[1..]) { + MessageType::UserEvent => match decode_message(ty, &msg[1..]) { Ok((_, ue)) => { if let SerfMessage::UserEvent(ue) = ue { tracing::debug!("serf: user event message: {}", ue.name); @@ -230,7 +230,7 @@ where tracing::warn!(err=%e, "serf: failed to decode message"); } }, - MessageType::Query => match ::decode_message(ty, &msg[1..]) { + MessageType::Query => match decode_message(ty, &msg[1..]) { Ok((_, q)) => { if let SerfMessage::Query(q) = q { tracing::debug!("serf: query message: {}", q.name); @@ -255,22 +255,20 @@ where tracing::warn!(err=%e, "serf: failed to decode message"); } }, - MessageType::QueryResponse => { - match ::decode_message(ty, &msg[1..]) { - Ok((_, qr)) => { - if let SerfMessage::QueryResponse(qr) = qr { - tracing::debug!("serf: query response message: {}", qr.from); - this.handle_query_response(qr).await; - } else { - tracing::warn!("serf: receive unexpected message: {}", qr.ty().as_str()); - } - } - Err(e) => { - tracing::warn!(err=%e, "serf: failed to decode message"); + MessageType::QueryResponse => match decode_message(ty, &msg[1..]) { + Ok((_, qr)) => { + if let SerfMessage::QueryResponse(qr) = qr { + 
tracing::debug!("serf: query response message: {}", qr.from); + this.handle_query_response(qr).await; + } else { + tracing::warn!("serf: receive unexpected message: {}", qr.ty().as_str()); } } - } - MessageType::Relay => match ::decode_node(&msg[1..]) { + Err(e) => { + tracing::warn!(err=%e, "serf: failed to decode message"); + } + }, + MessageType::Relay => match decode_node(&msg[1..]) { Ok((consumed, n)) => { tracing::debug!("serf: relay message",); tracing::debug!("serf: relaying response to node: {}", n); @@ -389,7 +387,7 @@ where .iter() .map(|v| v.member.node().id().cheap_clone()) .collect::>(); - let pp = PushPullMessageRef { + let pp = PushPullMessageBorrow { ltime: this.inner.clock.time(), status_ltimes: &status_ltimes, left_members: &left_members, @@ -399,19 +397,8 @@ where }; drop(members); - let expected_encoded_len = ::message_encoded_len(pp); - let mut buf = BytesMut::with_capacity(expected_encoded_len + 1); // +1 for the message type byte - buf.put_u8(MessageType::PushPull as u8); - buf.resize(expected_encoded_len + 1, 0); - match ::encode_message(pp, &mut buf[1..]) { - Ok(encoded_len) => { - debug_assert_eq!( - expected_encoded_len, encoded_len, - "expected encoded len {} mismatch the actual encoded len {}", - expected_encoded_len, encoded_len - ); - buf.freeze() - } + match serf_proto::Encodable::encode_to_bytes(&pp) { + Ok(buf) => buf, Err(e) => { tracing::error!(err=%e, "serf: failed to encode local state"); Bytes::new() @@ -449,7 +436,7 @@ where match ty { MessageType::PushPull => { - match ::decode_message(ty, &buf[1..]) { + match decode_message(ty, &buf[1..]) { Err(e) => { tracing::error!(err=%e, "serf: failed to decode remote state"); } @@ -677,9 +664,9 @@ where coord.portion.resize(len * 2, 0.0); // The rest of the message is the serialized coordinate. 
- let len = ::coordinate_encoded_len(&coord); + let len = coordinate_encoded_len(&coord); buf.resize(len + 1, 0); - if let Err(e) = ::encode_coordinate(&coord, &mut buf[1..]) { + if let Err(e) = encode_coordinate(&coord, &mut buf[1..]) { panic!("failed to encode coordinate: {}", e); } return buf.freeze(); @@ -687,12 +674,12 @@ where if let Some(c) = self.this().inner.coord_core.as_ref() { let coord = c.client.get_coordinate(); - let encoded_len = ::coordinate_encoded_len(&coord) + 1; + let encoded_len = coordinate_encoded_len(&coord) + 1; let mut buf = BytesMut::with_capacity(encoded_len); buf.put_u8(PING_VERSION); buf.resize(encoded_len, 0); - if let Err(e) = ::encode_coordinate(&coord, &mut buf[1..]) { + if let Err(e) = encode_coordinate(&coord, &mut buf[1..]) { tracing::error!(err=%e, "serf: failed to encode coordinate"); } buf.into() @@ -721,7 +708,7 @@ where } // Process the remainder of the message as a coordinate. - let coord = match ::decode_coordinate(&payload[1..]) { + let coord = match decode_coordinate(&payload[1..]) { Ok((readed, c)) => { tracing::trace!(read=%readed, coordinate=?c, "serf: decode coordinate successfully"); c @@ -815,7 +802,7 @@ where Ok(Member { node: node.node(), tags: if !node.meta().is_empty() { - ::decode_tags(node.meta()) + decode_tags(node.meta()) .map(|(read, tags)| { tracing::trace!(read=%read, tags=?tags, "serf: decode tags successfully"); Arc::new(tags) diff --git a/serf-core/src/serf/internal_query.rs b/serf-core/src/serf/internal_query.rs index 5f07c65..b00c645 100644 --- a/serf-core/src/serf/internal_query.rs +++ b/serf-core/src/serf/internal_query.rs @@ -8,16 +8,13 @@ use memberlist_core::{ }; use crate::{ - delegate::{Delegate, }, + delegate::Delegate, event::{CrateEvent, InternalQueryEvent, QueryEvent}, types::MessageType, }; #[cfg(feature = "encryption")] -use crate::{ - error::Error, - types::{KeyResponseMessage, SerfMessage}, -}; +use crate::{error::Error, types::KeyResponseMessage}; #[cfg(feature = "encryption")] 
use smol_str::SmolStr; @@ -151,20 +148,10 @@ where // Encode the response match out { Some(state) => { - let member = state.member(); - let expected_encoded_len = ::message_encoded_len(member); - let mut raw = BytesMut::with_capacity(expected_encoded_len + 1); // +1 for the message type - raw.put_u8(MessageType::ConflictResponse as u8); - raw.resize(expected_encoded_len + 1, 0); - match ::encode_message(member, &mut raw[1..]) { - Ok(len) => { - debug_assert_eq!( - len, expected_encoded_len, - "expected encoded len {} mismatch the actual encoded len {}", - expected_encoded_len, len - ); - - if let Err(e) = ev.respond(raw.freeze()).await { + let resp = serf_proto::ConflictResponseMessageBorrow::from(state.member()); + match serf_proto::Encodable::encode_to_bytes(&resp) { + Ok(raw) => { + if let Err(e) = ev.respond(raw).await { tracing::error!(target="serf", err=%e, "failed to respond to conflict query"); } } @@ -192,26 +179,25 @@ where async fn handle_install_key(ev: impl AsRef> + Send) { let q = ev.as_ref(); let mut response = KeyResponseMessage::default(); - let req = - match ::decode_message(MessageType::KeyRequest, &q.payload[1..]) { - Ok((_, msg)) => match msg { - SerfMessage::KeyRequest(req) => req, - msg => { - tracing::error!( - err = "unexpected message type", - "serf: {}", - msg.ty().as_str() - ); - Self::send_key_response(q, &mut response).await; - return; - } - }, - Err(e) => { - tracing::error!(err=%e, "serf: failed to decode key request"); + let req = match decode_message(MessageType::KeyRequest, &q.payload[1..]) { + Ok((_, msg)) => match msg { + SerfMessage::KeyRequest(req) => req, + msg => { + tracing::error!( + err = "unexpected message type", + "serf: {}", + msg.ty().as_str() + ); Self::send_key_response(q, &mut response).await; return; } - }; + }, + Err(e) => { + tracing::error!(err=%e, "serf: failed to decode key request"); + Self::send_key_response(q, &mut response).await; + return; + } + }; if !q.ctx.this.encryption_enabled() { tracing::error!( 
@@ -227,7 +213,7 @@ where let kr = q.ctx.this.inner.memberlist.keyring(); match kr { Some(kr) => { - kr.insert(req.key.unwrap()).await; + kr.insert(req.key.unwrap()); if q.ctx.this.inner.opts.keyring_file.is_some() { if let Err(e) = q.ctx.this.write_keyring_file().await { tracing::error!(err=%e, "serf: failed to write keyring file"); @@ -256,26 +242,25 @@ where let q = ev.as_ref(); let mut response = KeyResponseMessage::default(); - let req = - match ::decode_message(MessageType::KeyRequest, &q.payload[1..]) { - Ok((_, msg)) => match msg { - SerfMessage::KeyRequest(req) => req, - msg => { - tracing::error!( - err = "unexpected message type", - "serf: {}", - msg.ty().as_str() - ); - Self::send_key_response(q, &mut response).await; - return; - } - }, - Err(e) => { - tracing::error!(err=%e, "serf: failed to decode key request"); + let req = match decode_message(MessageType::KeyRequest, &q.payload[1..]) { + Ok((_, msg)) => match msg { + SerfMessage::KeyRequest(req) => req, + msg => { + tracing::error!( + err = "unexpected message type", + "serf: {}", + msg.ty().as_str() + ); Self::send_key_response(q, &mut response).await; return; } - }; + }, + Err(e) => { + tracing::error!(err=%e, "serf: failed to decode key request"); + Self::send_key_response(q, &mut response).await; + return; + } + }; if !q.ctx.this.encryption_enabled() { tracing::error!( @@ -291,7 +276,7 @@ where let kr = q.ctx.this.inner.memberlist.keyring(); match kr { Some(kr) => { - if let Err(e) = kr.use_key(&req.key.unwrap()).await { + if let Err(e) = kr.use_key(&req.key.unwrap()) { tracing::error!(err=%e, "serf: failed to change primary key"); response.message = SmolStr::new(e.to_string()); Self::send_key_response(q, &mut response).await; @@ -326,26 +311,25 @@ where let q = ev.as_ref(); let mut response = KeyResponseMessage::default(); - let req = - match ::decode_message(MessageType::KeyRequest, &q.payload[1..]) { - Ok((_, msg)) => match msg { - SerfMessage::KeyRequest(req) => req, - msg => { - 
tracing::error!( - err = "unexpected message type", - "serf: {}", - msg.ty().as_str() - ); - Self::send_key_response(q, &mut response).await; - return; - } - }, - Err(e) => { - tracing::error!(target="serf", err=%e, "failed to decode key request"); + let req = match decode_message(MessageType::KeyRequest, &q.payload[1..]) { + Ok((_, msg)) => match msg { + SerfMessage::KeyRequest(req) => req, + msg => { + tracing::error!( + err = "unexpected message type", + "serf: {}", + msg.ty().as_str() + ); Self::send_key_response(q, &mut response).await; return; } - }; + }, + Err(e) => { + tracing::error!(target="serf", err=%e, "failed to decode key request"); + Self::send_key_response(q, &mut response).await; + return; + } + }; if !q.ctx.this.encryption_enabled() { tracing::error!( @@ -361,7 +345,7 @@ where let kr = q.ctx.this.inner.memberlist.keyring(); match kr { Some(kr) => { - if let Err(e) = kr.remove(&req.key.unwrap()).await { + if let Err(e) = kr.remove(&req.key.unwrap()) { tracing::error!(err=%e, "serf: failed to remove key"); response.message = SmolStr::new(e.to_string()); Self::send_key_response(q, &mut response).await; @@ -411,11 +395,11 @@ where let kr = q.ctx.this.inner.memberlist.keyring(); match kr { Some(kr) => { - for k in kr.keys().await { + for k in kr.keys() { response.keys.push(k); } - let primary_key = kr.primary_key().await; + let primary_key = kr.primary_key(); response.primary_key = Some(primary_key); response.result = true; Self::send_key_response(q, &mut response).await; @@ -450,43 +434,14 @@ where (q.ctx.this.inner.opts.query_response_size_limit / MIN_ENCODED_KEY_LENGTH).min(actual); for i in (0..=max_list_keys).rev() { - let expected_k_encoded_len = ::message_encoded_len(&*resp); - let mut raw = BytesMut::with_capacity(expected_k_encoded_len + 1); // +1 for the message type - raw.put_u8(MessageType::KeyResponse as u8); - raw.resize(expected_k_encoded_len + 1, 0); - - let len = ::encode_message(&*resp, &mut raw[1..]) - 
.map_err(Error::transform_delegate)?; - - debug_assert_eq!( - len, expected_k_encoded_len, - "expected encoded len {} mismatch the actual encoded len {}", - expected_k_encoded_len, len - ); - let kraw = raw.freeze(); + let kraw = serf_proto::Encodable::encode_to_bytes(&*resp)?; // create response let qresp = q.create_response(kraw.clone()); - // encode response - let expected_encoded_len = ::message_encoded_len(&qresp); - let mut raw = BytesMut::with_capacity(expected_encoded_len + 1); // +1 for the message type - raw.put_u8(MessageType::QueryResponse as u8); - raw.resize(expected_encoded_len + 1, 0); - - let len = ::encode_message(&qresp, &mut raw[1..]) - .map_err(Error::transform_delegate)?; - - debug_assert_eq!( - len, expected_encoded_len, - "expected encoded len {} mismatch the actual encoded len {}", - expected_encoded_len, len - ); - - let qraw = raw.freeze(); - + let encoded_len = serf_proto::Encodable::encoded_len(&qresp); // Check the size limit - if q.check_response_size(&qraw).is_err() { + if q.check_response_size(encoded_len).is_err() { resp.keys.drain(i..); resp.message = SmolStr::new(format!( "truncated key list response, showing first {} of {} keys", @@ -495,6 +450,9 @@ where continue; } + // encode response + let qraw = serf_proto::Encodable::encode_to_bytes(&qresp)?; + if actual > i { tracing::warn!("serf: {}", resp.message); } @@ -506,7 +464,7 @@ where #[cfg(feature = "encryption")] async fn send_key_response(q: &QueryEvent, resp: &mut KeyResponseMessage) { match q.name.as_str() { - "serf-list-keys" => { + "_serf_list_keys" => { let (raw, qresp) = match Self::key_list_response_with_correct_size(q, resp) { Ok((raw, qresp)) => (raw, qresp), Err(e) => { @@ -519,28 +477,16 @@ where tracing::error!(target="serf", err=%e, "failed to respond to key query"); } } - _ => { - let expected_encoded_len = ::message_encoded_len(&*resp); - let mut raw = BytesMut::with_capacity(expected_encoded_len + 1); // +1 for the message type - 
raw.put_u8(MessageType::KeyResponse as u8); - raw.resize(expected_encoded_len + 1, 0); - match ::encode_message(&*resp, &mut raw[1..]) { - Ok(len) => { - debug_assert_eq!( - len, expected_encoded_len, - "expected encoded len {} mismatch the actual encoded len {}", - expected_encoded_len, len - ); - - if let Err(e) = q.respond(raw.freeze()).await { - tracing::error!(target="serf", err=%e, "failed to respond to key query"); - } - } - Err(e) => { - tracing::error!(target="serf", err=%e, "failed to encode key response"); + _ => match serf_proto::Encodable::encode_to_bytes(&*resp) { + Ok(raw) => { + if let Err(e) = q.respond(raw).await { + tracing::error!(target="serf", err=%e, "failed to respond to key query"); } } - } + Err(e) => { + tracing::error!(target="serf", err=%e, "failed to encode key response"); + } + }, } } } diff --git a/serf-core/src/serf/query.rs b/serf-core/src/serf/query.rs index a8e59ae..19cad1f 100644 --- a/serf-core/src/serf/query.rs +++ b/serf-core/src/serf/query.rs @@ -10,13 +10,13 @@ use futures::{FutureExt, StreamExt, stream::FuturesUnordered}; use memberlist_core::{ CheapClone, bytes::{BufMut, Bytes, BytesMut}, + proto::{OneOrMore, SmallVec, TinyVec}, tracing, transport::{AddressResolver, Id, Node, Transport}, - types::{OneOrMore, SmallVec, TinyVec}, }; use crate::{ - delegate::{Delegate, }, + delegate::Delegate, error::Error, types::{ Filter, LamportTime, Member, MemberStatus, MessageType, QueryMessage, QueryResponseMessage, @@ -396,7 +396,7 @@ fn random_members(k: usize, mut members: SmallVec>) -> SmallV let mut i = 0; while i < rounds && i < n { - let j = rand::random::() % (n - i) + i; + let j = (rand::random::() as usize) % (n - i) + i; members.swap(i, j); i += 1; if i >= k && i >= rounds { @@ -444,7 +444,7 @@ where } // Decode the filter - let filter = match ::decode_filter(filter) { + let filter = match decode_filter(filter) { Ok((read, filter)) => { tracing::trace!(read=%read, filter=?filter, "serf: decoded filter successully"); filter 
@@ -461,12 +461,14 @@ where match filter { Filter::Id(nodes) => { // Check if we are being targeted - let found = nodes.iter().any(|n| n.eq(self.inner.memberlist.local_id())); + let found = nodes + .iter() + .any(|n: &T::Id| n.eq(self.inner.memberlist.local_id())); if !found { return false; } } - Filter::Tag { tag, expr: fexpr } => { + Filter::Tag(tag) => { // Check if we match this regex let tags = self.inner.opts.tags.load(); if !tags.is_empty() { @@ -530,35 +532,15 @@ where } // Prep the relay message, which is a wrapped version of the original. - // let relay_msg = SerfRelayMessage::new(node, SerfMessage::QueryResponse(resp)); - let expected_encoded_len = 1 - + ::node_encoded_len(&node) - + 1 - + ::message_encoded_len(&resp); // +1 for relay message type byte, +1 for the message type - if expected_encoded_len > self.inner.opts.query_response_size_limit { + let encoded_len = serf_proto::Encodable::encoded_len_with_relay(&resp); + if encoded_len > self.inner.opts.query_response_size_limit { return Err(Error::relayed_response_too_large( self.inner.opts.query_response_size_limit, )); } - let mut raw = BytesMut::with_capacity(expected_encoded_len + 1 + 1); // +1 for relay message type byte, +1 for the message type byte - raw.put_u8(MessageType::Relay as u8); - raw.resize(expected_encoded_len + 1 + 1, 0); - let mut encoded = 1; - encoded += ::encode_node(&node, &mut raw[encoded..]) - .map_err(Error::transform_delegate)?; - raw[encoded] = MessageType::QueryResponse as u8; - encoded += 1; - encoded += ::encode_message(&resp, &mut raw[encoded..]) - .map_err(Error::transform_delegate)?; - - debug_assert_eq!( - encoded, expected_encoded_len, - "expected encoded len {} mismatch the actual encoded len {}", - expected_encoded_len, encoded - ); - - let raw = raw.freeze(); + let raw = serf_proto::Encodable::encode_relay_to_bytes(&resp)?; + // Relay to a random set of peers. 
let relay_members = random_members(relay_factor as usize, members); diff --git a/serf-core/src/snapshot.rs b/serf-core/src/snapshot.rs index bdb9e04..8010630 100644 --- a/serf-core/src/snapshot.rs +++ b/serf-core/src/snapshot.rs @@ -18,9 +18,9 @@ use memberlist_core::{ CheapClone, agnostic_lite::{AsyncSpawner, RuntimeLite}, bytes::{BufMut, BytesMut}, + proto::{Data, TinyVec}, tracing, transport::{AddressResolver, Id, MaybeResolvedAddress, Node, Transport}, - proto::TinyVec, }; use rand::seq::SliceRandom; use serf_proto::UserEventMessage; @@ -160,19 +160,19 @@ const MAX_INLINED_BYTES: usize = 64; macro_rules! encode { ($w:ident.$node: ident::$status: ident) => {{ let node = $node.as_ref(); - let encoded_node_len = T::node_encoded_len(node); + let encoded_node_len = node.encoded_len(); let encoded_len = 4 + 1 + encoded_node_len; if encoded_len <= MAX_INLINED_BYTES { let mut buf = [0u8; MAX_INLINED_BYTES]; buf[0] = Self::$status; buf[1..5].copy_from_slice(&(encoded_node_len as u32).to_le_bytes()); - T::encode_node(node, &mut buf[5..]).map_err(invalid_data_io_error)?; + node.encode(&mut buf[5..]).map_err(invalid_data_io_error)?; $w.write_all(&buf[..encoded_len]).map(|_| encoded_len) } else { let mut buf = BytesMut::with_capacity(encoded_len); buf.put_u8(Self::$status); buf.put_u32_le(encoded_node_len as u32); - T::encode_node(node, &mut buf).map_err(invalid_data_io_error)?; + node.encode(&mut buf).map_err(invalid_data_io_error)?; $w.write_all(&buf).map(|_| encoded_len) } }}; @@ -188,8 +188,8 @@ macro_rules! 
encode { impl SnapshotRecord<'_, I, A> where - I: Id, - A: CheapClone + Send + Sync + 'static, + I: Id + Data, + A: CheapClone + Data + Send + Sync + 'static, { const ALIVE: u8 = 0; const NOT_ALIVE: u8 = 1; @@ -200,10 +200,7 @@ where const LEAVE: u8 = 6; const COMMENT: u8 = 7; - fn encode( - &self, - w: &mut W, - ) -> std::io::Result { + fn encode(&self, w: &mut W) -> std::io::Result { match self { Self::Alive(id) => encode!(w.id::ALIVE), Self::NotAlive(id) => encode!(w.id::NOT_ALIVE), @@ -229,8 +226,8 @@ pub(crate) struct ReplayResult { } pub(crate) fn open_and_replay_snapshot< - I: Id, - A: CheapClone + core::hash::Hash + Eq + Send + Sync + 'static, + I: Id + Data, + A: CheapClone + Data + core::hash::Hash + Eq + Send + Sync + 'static, P: AsRef, >( p: &P, @@ -284,8 +281,8 @@ pub(crate) fn open_and_replay_snapshot< buf.resize(len, 0); reader.read_exact(&mut buf).map_err(SnapshotError::Replay)?; - let (_, node) = - T::decode_node(&buf).map_err(|e| SnapshotError::Replay(invalid_data_io_error(e)))?; + let (_, node) = as Data>::decode(&buf) + .map_err(|e| SnapshotError::Replay(invalid_data_io_error(e)))?; alive_nodes.insert(node); } SnapshotRecordType::NotAlive => { @@ -295,8 +292,8 @@ pub(crate) fn open_and_replay_snapshot< buf.resize(len, 0); reader.read_exact(&mut buf).map_err(SnapshotError::Replay)?; - let (_, node) = - T::decode_node(&buf).map_err(|e| SnapshotError::Replay(invalid_data_io_error(e)))?; + let (_, node) = as Data>::decode(&buf) + .map_err(|e| SnapshotError::Replay(invalid_data_io_error(e)))?; alive_nodes.remove(&node); } SnapshotRecordType::Clock => { @@ -750,7 +747,7 @@ where ); let f = self.fh.as_mut().unwrap(); - let n = l.encode::(f).map_err(SnapshotError::Write)?; + let n = l.encode(f).map_err(SnapshotError::Write)?; // check if we should flush if self.last_flush.elapsed() > FLUSH_INTERVAL { @@ -818,21 +815,21 @@ where let mut offset = 0u64; for node in self.alive_nodes.iter() { offset += SnapshotRecord::Alive(Cow::Borrowed(node)) - 
.encode::(&mut buf) + .encode(&mut buf) .map_err(SnapshotError::WriteNew)? as u64; } // Write out the clocks - offset += SnapshotRecord::Clock(self.last_clock) - .encode::(&mut buf) + offset += SnapshotRecord::::Clock(self.last_clock) + .encode(&mut buf) .map_err(SnapshotError::WriteNew)? as u64; - offset += SnapshotRecord::EventClock(self.last_event_clock) - .encode::(&mut buf) + offset += SnapshotRecord::::EventClock(self.last_event_clock) + .encode(&mut buf) .map_err(SnapshotError::WriteNew)? as u64; - offset += SnapshotRecord::QueryClock(self.last_query_clock) - .encode::(&mut buf) + offset += SnapshotRecord::::QueryClock(self.last_query_clock) + .encode(&mut buf) .map_err(SnapshotError::WriteNew)? as u64; // Flush the new snapshot diff --git a/serf-core/src/types/member.rs b/serf-core/src/types/member.rs index 43b7b3a..6400a81 100644 --- a/serf-core/src/types/member.rs +++ b/serf-core/src/types/member.rs @@ -1,5 +1,5 @@ use memberlist_core::proto::OneOrMore; -use serf_proto::Member; +use serf_proto::{Member, MessageType}; use std::collections::HashMap; diff --git a/serf-proto/src/conflict.rs b/serf-proto/src/conflict.rs new file mode 100644 index 0000000..087959a --- /dev/null +++ b/serf-proto/src/conflict.rs @@ -0,0 +1,174 @@ +use memberlist_proto::{ + Data, DataRef, DecodeError, EncodeError, WireType, + utils::{merge, skip, split}, +}; + +use super::*; + +/// A conflict message +#[viewit::viewit(setters(prefix = "with"))] +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct ConflictResponseMessage { + /// The member + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the member")), + setter(attrs(doc = "Sets the member (Builder pattern)")) + )] + member: Member, +} + +impl ConflictResponseMessage { + /// Create a new conflict response message + pub fn new(member: Member) -> Self { + Self { member } + } +} + +/// The 
borrow type of conflict message +#[viewit::viewit(setters(prefix = "with"))] +#[derive(Debug, PartialEq)] +pub struct ConflictResponseMessageBorrow<'a, I, A> { + /// The member + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the member")), + setter(attrs(doc = "Sets the member (Builder pattern)")) + )] + member: &'a Member, +} + +impl<'a, I, A> ConflictResponseMessageBorrow<'a, I, A> { + /// Create a new conflict response message + pub fn new(member: &'a Member) -> Self { + Self { member } + } +} + +impl<'a, I, A> From<&'a ConflictResponseMessage> for ConflictResponseMessageBorrow<'a, I, A> { + fn from(val: &'a ConflictResponseMessage) -> Self { + Self::new(&val.member) + } +} + +impl<'a, I, A> From<&'a Member> for ConflictResponseMessageBorrow<'a, I, A> { + fn from(val: &'a Member) -> Self { + Self::new(val) + } +} + +impl ConflictResponseMessageBorrow<'_, I, A> +where + I: Data, + A: Data, +{ + pub(super) fn encoded_len_in(&self) -> usize { + 1 + self.member.encoded_len_with_length_delimited() + } + + pub(super) fn encode_in(&self, buf: &mut [u8]) -> Result { + let mut offset = 0; + + if offset >= buf.len() { + return Err(EncodeError::insufficient_buffer( + self.encoded_len_in(), + buf.len(), + )); + } + + buf[offset] = MEMBER_BYTE; + offset += 1; + offset += self + .member + .encode_length_delimited(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len_in(), buf.len()))?; + + #[cfg(debug_assertions)] + super::debug_assert_write_eq(offset, self.encoded_len_in()); + + Ok(offset) + } +} + +const MEMBER_TAG: u8 = 1; +const MEMBER_BYTE: u8 = merge(WireType::LengthDelimited, MEMBER_TAG); + +/// The reference to a [`ConflictResponseMessage`]. 
+#[viewit::viewit(getters(style = "ref", vis_all = "pub"), setters(skip), vis_all = "")] +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct ConflictResponseMessageRef<'a, I, A> { + #[viewit(getter(const, style = "ref", attrs(doc = "Returns the member")))] + member: MemberRef<'a, I, A>, +} + +impl<'a, I, A> DataRef<'a, ConflictResponseMessage> + for ConflictResponseMessageRef<'a, I::Ref<'a>, A::Ref<'a>> +where + I: Data, + A: Data, +{ + fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> + where + Self: Sized, + { + let mut offset = 0; + let mut member = None; + + while offset < buf.len() { + match buf[offset] { + MEMBER_BYTE => { + if member.is_some() { + return Err(DecodeError::duplicate_field( + "ConflictResponseMessage", + "member", + MEMBER_TAG, + )); + } + offset += 1; + + let (len, val) = , A::Ref<'_>> as DataRef<'_, Member>>::decode_length_delimited(&buf[offset..])?; + offset += len; + member = Some(val); + } + other => { + offset += 1; + + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &buf[offset..])?; + } + } + } + + let member = member.ok_or(DecodeError::missing_field( + "ConflictResponseMessage", + "member", + ))?; + Ok((offset, Self { member })) + } +} + +impl Data for ConflictResponseMessage +where + I: Data, + A: Data, +{ + type Ref<'a> = ConflictResponseMessageRef<'a, I::Ref<'a>, A::Ref<'a>>; + + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + Ok(Self { + member: Member::from_ref(val.member)?, + }) + } + + fn encoded_len(&self) -> usize { + ConflictResponseMessageBorrow::from(self).encoded_len() + } + + fn encode(&self, buf: &mut [u8]) -> Result { + ConflictResponseMessageBorrow::from(self).encode_in(buf) + } +} diff --git a/serf-proto/src/lib.rs b/serf-proto/src/lib.rs index 8333a6d..8c42dc2 100644 --- a/serf-proto/src/lib.rs +++ b/serf-proto/src/lib.rs @@ -17,6 +17,9 @@ mod arbitrary_impl; mod clock; pub 
use clock::*; +mod conflict; +pub use conflict::*; + mod filter; pub use filter::*; @@ -26,8 +29,8 @@ pub use leave::*; mod member; pub use member::*; -// mod message; -// pub use message::*; +mod message; +pub use message::*; mod join; pub use join::*; diff --git a/serf-proto/src/message.rs b/serf-proto/src/message.rs index b0bddec..0757c9e 100644 --- a/serf-proto/src/message.rs +++ b/serf-proto/src/message.rs @@ -1,370 +1,309 @@ -// use std::sync::Arc; - -// use super::{ -// JoinMessage, LeaveMessage, Member, PushPullMessage, PushPullMessageRef, QueryMessage, -// QueryResponseMessage, UserEventMessage, -// }; - -// #[cfg(feature = "encryption")] -// use super::{KeyRequestMessage, KeyResponseMessage}; - -// const LEAVE_MESSAGE_TAG: u8 = 0; -// const JOIN_MESSAGE_TAG: u8 = 1; -// const PUSH_PULL_MESSAGE_TAG: u8 = 2; -// const USER_EVENT_MESSAGE_TAG: u8 = 3; -// const QUERY_MESSAGE_TAG: u8 = 4; -// const QUERY_RESPONSE_MESSAGE_TAG: u8 = 5; -// const CONFLICT_RESPONSE_MESSAGE_TAG: u8 = 6; -// const RELAY_MESSAGE_TAG: u8 = 7; -// #[cfg(feature = "encryption")] -// const KEY_REQUEST_MESSAGE_TAG: u8 = 253; -// #[cfg(feature = "encryption")] -// const KEY_RESPONSE_MESSAGE_TAG: u8 = 254; - -// /// Unknown message type error -// #[derive(Debug, thiserror::Error)] -// #[error("unknown message type byte: {0}")] -// pub struct UnknownMessageType(u8); - -// impl TryFrom for MessageType { -// type Error = UnknownMessageType; - -// fn try_from(value: u8) -> Result { -// Ok(match value { -// LEAVE_MESSAGE_TAG => Self::Leave, -// JOIN_MESSAGE_TAG => Self::Join, -// PUSH_PULL_MESSAGE_TAG => Self::PushPull, -// USER_EVENT_MESSAGE_TAG => Self::UserEvent, -// QUERY_MESSAGE_TAG => Self::Query, -// QUERY_RESPONSE_MESSAGE_TAG => Self::QueryResponse, -// CONFLICT_RESPONSE_MESSAGE_TAG => Self::ConflictResponse, -// RELAY_MESSAGE_TAG => Self::Relay, -// #[cfg(feature = "encryption")] -// KEY_REQUEST_MESSAGE_TAG => Self::KeyRequest, -// #[cfg(feature = "encryption")] -// 
KEY_RESPONSE_MESSAGE_TAG => Self::KeyResponse, -// _ => return Err(UnknownMessageType(value)), -// }) -// } -// } - -// impl From for u8 { -// fn from(value: MessageType) -> Self { -// value as u8 -// } -// } - -// /// The types of gossip messages Serf will send along -// /// memberlist. -// #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -// #[repr(u8)] -// #[non_exhaustive] -// pub enum MessageType { -// /// Leave message -// Leave = LEAVE_MESSAGE_TAG, -// /// Join message -// Join = JOIN_MESSAGE_TAG, -// /// PushPull message -// PushPull = PUSH_PULL_MESSAGE_TAG, -// /// UserEvent message -// UserEvent = USER_EVENT_MESSAGE_TAG, -// /// Query message -// Query = QUERY_MESSAGE_TAG, -// /// QueryResponse message -// QueryResponse = QUERY_RESPONSE_MESSAGE_TAG, -// /// ConflictResponse message -// ConflictResponse = CONFLICT_RESPONSE_MESSAGE_TAG, -// /// Relay message -// Relay = RELAY_MESSAGE_TAG, -// /// KeyRequest message -// #[cfg(feature = "encryption")] -// KeyRequest = KEY_REQUEST_MESSAGE_TAG, -// /// KeyResponse message -// #[cfg(feature = "encryption")] -// KeyResponse = KEY_RESPONSE_MESSAGE_TAG, -// } - -// impl MessageType { -// /// Get the string representation of the message type -// #[inline] -// pub const fn as_str(&self) -> &'static str { -// match self { -// Self::Leave => "leave", -// Self::Join => "join", -// Self::PushPull => "push pull", -// Self::UserEvent => "user event", -// Self::Query => "query", -// Self::QueryResponse => "query response", -// Self::ConflictResponse => "conflict response", -// Self::Relay => "relay", -// #[cfg(feature = "encryption")] -// Self::KeyRequest => "key request", -// #[cfg(feature = "encryption")] -// Self::KeyResponse => "key response", -// } -// } -// } - -// /// Used to do a cheap reference to message reference conversion. -// pub trait AsMessageRef { -// /// Converts this type into a shared reference of the (usually inferred) input type. 
-// fn as_message_ref(&self) -> SerfMessageRef<'_, I, A>; -// } - -// /// The reference type of [`SerfMessage`]. -// #[derive(Debug)] -// pub enum SerfMessageRef<'a, I, A> { -// /// Leave message reference -// Leave(&'a LeaveMessage), -// /// Join message reference -// Join(&'a JoinMessage), -// /// PushPull message reference -// PushPull(PushPullMessageRef<'a, I>), -// /// UserEvent message reference -// UserEvent(&'a UserEventMessage), -// /// Query message reference -// Query(&'a QueryMessage), -// /// QueryResponse message reference -// QueryResponse(&'a QueryResponseMessage), -// /// ConflictResponse message reference -// ConflictResponse(&'a Member), -// /// KeyRequest message reference -// #[cfg(feature = "encryption")] -// KeyRequest(&'a KeyRequestMessage), -// /// KeyResponse message reference -// #[cfg(feature = "encryption")] -// KeyResponse(&'a KeyResponseMessage), -// } - -// impl Clone for SerfMessageRef<'_, I, A> { -// fn clone(&self) -> Self { -// *self -// } -// } - -// impl Copy for SerfMessageRef<'_, I, A> {} - -// impl AsMessageRef for SerfMessageRef<'_, I, A> { -// fn as_message_ref(&self) -> SerfMessageRef { -// *self -// } -// } - -// /// The types of gossip messages Serf will send along -// /// memberlist. 
-// #[derive(Debug, Clone)] -// pub enum SerfMessage { -// /// Leave message -// Leave(LeaveMessage), -// /// Join message -// Join(JoinMessage), -// /// PushPull message -// PushPull(PushPullMessage), -// /// UserEvent message -// UserEvent(UserEventMessage), -// /// Query message -// Query(QueryMessage), -// /// QueryResponse message -// QueryResponse(QueryResponseMessage), -// /// ConflictResponse message -// ConflictResponse(Member), -// /// Relay message -// #[cfg(feature = "encryption")] -// KeyRequest(KeyRequestMessage), -// /// KeyResponse message -// #[cfg(feature = "encryption")] -// KeyResponse(KeyResponseMessage), -// } - -// impl<'a, I, A> From<&'a SerfMessage> for MessageType { -// fn from(msg: &'a SerfMessage) -> Self { -// match msg { -// SerfMessage::Leave(_) => MessageType::Leave, -// SerfMessage::Join(_) => MessageType::Join, -// SerfMessage::PushPull(_) => MessageType::PushPull, -// SerfMessage::UserEvent(_) => MessageType::UserEvent, -// SerfMessage::Query(_) => MessageType::Query, -// SerfMessage::QueryResponse(_) => MessageType::QueryResponse, -// SerfMessage::ConflictResponse(_) => MessageType::ConflictResponse, -// #[cfg(feature = "encryption")] -// SerfMessage::KeyRequest(_) => MessageType::KeyRequest, -// #[cfg(feature = "encryption")] -// SerfMessage::KeyResponse(_) => MessageType::KeyResponse, -// } -// } -// } - -// impl AsMessageRef for QueryMessage { -// fn as_message_ref(&self) -> SerfMessageRef { -// SerfMessageRef::Query(self) -// } -// } - -// impl AsMessageRef for QueryResponseMessage { -// fn as_message_ref(&self) -> SerfMessageRef { -// SerfMessageRef::QueryResponse(self) -// } -// } - -// impl AsMessageRef for JoinMessage { -// fn as_message_ref(&self) -> SerfMessageRef { -// SerfMessageRef::Join(self) -// } -// } - -// impl AsMessageRef for UserEventMessage { -// fn as_message_ref(&self) -> SerfMessageRef { -// SerfMessageRef::UserEvent(self) -// } -// } - -// impl AsMessageRef for &QueryMessage { -// fn 
as_message_ref(&self) -> SerfMessageRef { -// SerfMessageRef::Query(self) -// } -// } - -// impl AsMessageRef for &QueryResponseMessage { -// fn as_message_ref(&self) -> SerfMessageRef { -// SerfMessageRef::QueryResponse(self) -// } -// } - -// impl AsMessageRef for &JoinMessage { -// fn as_message_ref(&self) -> SerfMessageRef { -// SerfMessageRef::Join(self) -// } -// } - -// impl AsMessageRef for PushPullMessageRef<'_, I> { -// fn as_message_ref(&self) -> SerfMessageRef { -// SerfMessageRef::PushPull(*self) -// } -// } - -// impl AsMessageRef for &PushPullMessage { -// fn as_message_ref(&self) -> SerfMessageRef { -// SerfMessageRef::PushPull(PushPullMessageRef { -// ltime: self.ltime, -// status_ltimes: &self.status_ltimes, -// left_members: &self.left_members, -// event_ltime: self.event_ltime, -// events: &self.events, -// query_ltime: self.query_ltime, -// }) -// } -// } - -// impl AsMessageRef for &UserEventMessage { -// fn as_message_ref(&self) -> SerfMessageRef { -// SerfMessageRef::UserEvent(self) -// } -// } - -// impl AsMessageRef for &LeaveMessage { -// fn as_message_ref(&self) -> SerfMessageRef { -// SerfMessageRef::Leave(self) -// } -// } - -// impl AsMessageRef for &Member { -// fn as_message_ref(&self) -> SerfMessageRef { -// SerfMessageRef::ConflictResponse(self) -// } -// } - -// impl AsMessageRef for &Arc> { -// fn as_message_ref(&self) -> SerfMessageRef { -// SerfMessageRef::ConflictResponse(self) -// } -// } - -// #[cfg(feature = "encryption")] -// impl AsMessageRef for &KeyRequestMessage { -// fn as_message_ref(&self) -> SerfMessageRef { -// SerfMessageRef::KeyRequest(self) -// } -// } - -// #[cfg(feature = "encryption")] -// impl AsMessageRef for &KeyResponseMessage { -// fn as_message_ref(&self) -> SerfMessageRef { -// SerfMessageRef::KeyResponse(self) -// } -// } - -// impl AsMessageRef for SerfMessage { -// fn as_message_ref(&self) -> SerfMessageRef { -// match self { -// Self::Leave(l) => SerfMessageRef::Leave(l), -// Self::Join(j) => 
SerfMessageRef::Join(j), -// Self::PushPull(pp) => SerfMessageRef::PushPull(PushPullMessageRef { -// ltime: pp.ltime, -// status_ltimes: &pp.status_ltimes, -// left_members: &pp.left_members, -// event_ltime: pp.event_ltime, -// events: &pp.events, -// query_ltime: pp.query_ltime, -// }), -// Self::UserEvent(u) => SerfMessageRef::UserEvent(u), -// Self::Query(q) => SerfMessageRef::Query(q), -// Self::QueryResponse(q) => SerfMessageRef::QueryResponse(q), -// Self::ConflictResponse(m) => SerfMessageRef::ConflictResponse(m), -// #[cfg(feature = "encryption")] -// Self::KeyRequest(kr) => SerfMessageRef::KeyRequest(kr), -// #[cfg(feature = "encryption")] -// Self::KeyResponse(kr) => SerfMessageRef::KeyResponse(kr), -// } -// } -// } - -// impl AsMessageRef for &SerfMessage { -// fn as_message_ref(&self) -> SerfMessageRef { -// match self { -// SerfMessage::Leave(l) => SerfMessageRef::Leave(l), -// SerfMessage::Join(j) => SerfMessageRef::Join(j), -// SerfMessage::PushPull(pp) => SerfMessageRef::PushPull(PushPullMessageRef { -// ltime: pp.ltime, -// status_ltimes: &pp.status_ltimes, -// left_members: &pp.left_members, -// event_ltime: pp.event_ltime, -// events: &pp.events, -// query_ltime: pp.query_ltime, -// }), -// SerfMessage::UserEvent(u) => SerfMessageRef::UserEvent(u), -// SerfMessage::Query(q) => SerfMessageRef::Query(q), -// SerfMessage::QueryResponse(q) => SerfMessageRef::QueryResponse(q), -// SerfMessage::ConflictResponse(m) => SerfMessageRef::ConflictResponse(m), -// #[cfg(feature = "encryption")] -// SerfMessage::KeyRequest(kr) => SerfMessageRef::KeyRequest(kr), -// #[cfg(feature = "encryption")] -// SerfMessage::KeyResponse(kr) => SerfMessageRef::KeyResponse(kr), -// } -// } -// } - -// impl core::fmt::Display for SerfMessage { -// fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { -// write!(f, "{}", self.ty().as_str()) -// } -// } - -// impl SerfMessage { -// /// Returns the message type of this message -// #[inline] -// pub const fn 
ty(&self) -> MessageType { -// match self { -// Self::Leave(_) => MessageType::Leave, -// Self::Join(_) => MessageType::Join, -// Self::PushPull(_) => MessageType::PushPull, -// Self::UserEvent(_) => MessageType::UserEvent, -// Self::Query(_) => MessageType::Query, -// Self::QueryResponse(_) => MessageType::QueryResponse, -// Self::ConflictResponse(_) => MessageType::ConflictResponse, -// #[cfg(feature = "encryption")] -// Self::KeyRequest(_) => MessageType::KeyRequest, -// #[cfg(feature = "encryption")] -// Self::KeyResponse(_) => MessageType::KeyResponse, -// } -// } -// } +use memberlist_proto::{Data, EncodeError, WireType, bytes::Bytes, utils::merge}; + +use super::{ + ConflictResponseMessage, ConflictResponseMessageBorrow, JoinMessage, LeaveMessage, + PushPullMessageBorrow, QueryMessage, QueryResponseMessage, UserEventMessage, +}; + +#[cfg(feature = "encryption")] +use super::{KeyRequestMessage, KeyResponseMessage}; + +const LEAVE_MESSAGE_TAG: u8 = 1; +const JOIN_MESSAGE_TAG: u8 = 2; +const PUSH_PULL_MESSAGE_TAG: u8 = 3; +const USER_EVENT_MESSAGE_TAG: u8 = 4; +const QUERY_MESSAGE_TAG: u8 = 5; +const QUERY_RESPONSE_MESSAGE_TAG: u8 = 6; +const CONFLICT_RESPONSE_MESSAGE_TAG: u8 = 7; +const RELAY_MESSAGE_TAG: u8 = 8; +#[cfg(feature = "encryption")] +const KEY_REQUEST_MESSAGE_TAG: u8 = 9; +#[cfg(feature = "encryption")] +const KEY_RESPONSE_MESSAGE_TAG: u8 = 10; + +const LEAVE_MESSAGE_BYTE: u8 = merge(WireType::LengthDelimited, LEAVE_MESSAGE_TAG); +const JOIN_MESSAGE_BYTE: u8 = merge(WireType::LengthDelimited, JOIN_MESSAGE_TAG); +const PUSH_PULL_MESSAGE_BYTE: u8 = merge(WireType::LengthDelimited, PUSH_PULL_MESSAGE_TAG); +const USER_EVENT_MESSAGE_BYTE: u8 = merge(WireType::LengthDelimited, USER_EVENT_MESSAGE_TAG); +const QUERY_MESSAGE_BYTE: u8 = merge(WireType::LengthDelimited, QUERY_MESSAGE_TAG); +const QUERY_RESPONSE_MESSAGE_BYTE: u8 = + merge(WireType::LengthDelimited, QUERY_RESPONSE_MESSAGE_TAG); +const CONFLICT_RESPONSE_MESSAGE_BYTE: u8 = + 
merge(WireType::LengthDelimited, CONFLICT_RESPONSE_MESSAGE_TAG); +const RELAY_MESSAGE_BYTE: u8 = merge(WireType::LengthDelimited, RELAY_MESSAGE_TAG); +#[cfg(feature = "encryption")] +const KEY_REQUEST_MESSAGE_BYTE: u8 = merge(WireType::LengthDelimited, KEY_REQUEST_MESSAGE_TAG); +#[cfg(feature = "encryption")] +const KEY_RESPONSE_MESSAGE_BYTE: u8 = merge(WireType::LengthDelimited, KEY_RESPONSE_MESSAGE_TAG); + +/// The types of gossip messages Serf will send along +/// memberlist. +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[repr(u8)] +#[non_exhaustive] +pub enum MessageType { + /// Leave message + Leave, + /// Join message + Join, + /// PushPull message + PushPull, + /// UserEvent message + UserEvent, + /// Query message + Query, + /// QueryResponse message + QueryResponse, + /// ConflictResponse message + ConflictResponse, + /// Relay message + Relay, + /// KeyRequest message + #[cfg(feature = "encryption")] + KeyRequest, + /// KeyResponse message + #[cfg(feature = "encryption")] + KeyResponse, + /// Unknown message type, used for forwards and backwards compatibility + Unknown(u8), +} + +impl MessageType { + /// Get the string representation of the message type + #[inline] + pub fn as_str(&self) -> std::borrow::Cow<'static, str> { + std::borrow::Cow::Borrowed(match self { + Self::Leave => "leave", + Self::Join => "join", + Self::PushPull => "push_pull", + Self::UserEvent => "user_event", + Self::Query => "query", + Self::QueryResponse => "query_response", + Self::ConflictResponse => "conflict_response", + Self::Relay => "relay", + #[cfg(feature = "encryption")] + Self::KeyRequest => "key_request", + #[cfg(feature = "encryption")] + Self::KeyResponse => "key_response", + Self::Unknown(val) => return std::borrow::Cow::Owned(format!("unknown({val})")), + }) + } +} + +impl From for MessageType { + fn from(value: u8) -> Self { + match value { + LEAVE_MESSAGE_TAG => Self::Leave, + JOIN_MESSAGE_TAG => Self::Join, + PUSH_PULL_MESSAGE_TAG => Self::PushPull, + 
USER_EVENT_MESSAGE_TAG => Self::UserEvent, + QUERY_MESSAGE_TAG => Self::Query, + QUERY_RESPONSE_MESSAGE_TAG => Self::QueryResponse, + CONFLICT_RESPONSE_MESSAGE_TAG => Self::ConflictResponse, + RELAY_MESSAGE_TAG => Self::Relay, + #[cfg(feature = "encryption")] + KEY_REQUEST_MESSAGE_TAG => Self::KeyRequest, + #[cfg(feature = "encryption")] + KEY_RESPONSE_MESSAGE_TAG => Self::KeyResponse, + val => Self::Unknown(val), + } + } +} + +impl From for u8 { + fn from(val: MessageType) -> Self { + match val { + MessageType::Leave => LEAVE_MESSAGE_TAG, + MessageType::Join => JOIN_MESSAGE_TAG, + MessageType::PushPull => PUSH_PULL_MESSAGE_TAG, + MessageType::UserEvent => USER_EVENT_MESSAGE_TAG, + MessageType::Query => QUERY_MESSAGE_TAG, + MessageType::QueryResponse => QUERY_RESPONSE_MESSAGE_TAG, + MessageType::ConflictResponse => CONFLICT_RESPONSE_MESSAGE_TAG, + MessageType::Relay => RELAY_MESSAGE_TAG, + #[cfg(feature = "encryption")] + MessageType::KeyRequest => KEY_REQUEST_MESSAGE_TAG, + #[cfg(feature = "encryption")] + MessageType::KeyResponse => KEY_RESPONSE_MESSAGE_TAG, + MessageType::Unknown(val) => val, + } + } +} + +macro_rules! bail { + ($this:ident($offset:expr, $len:ident)) => { + if $offset >= $len { + return Err(EncodeError::insufficient_buffer( + Encodable::encoded_len($this), + $len, + )); + } + }; +} + +/// A trait for encoding messages. +pub trait Encodable { + /// Encodes the message into a buffer. + fn encode(&self, buf: &mut [u8]) -> Result; + + /// Encodes a relay message into a buffer. 
+ fn encode_relay(&self, buf: &mut [u8]) -> Result { + let mut offset = 0; + let buf_len = buf.len(); + + if offset >= buf_len { + return Err(EncodeError::insufficient_buffer( + self.encoded_len_with_relay(), + buf_len, + )); + } + + buf[offset] = RELAY_MESSAGE_BYTE; + offset += 1; + + offset += self + .encode(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len_with_relay(), buf_len))?; + + #[cfg(debug_assertions)] + super::debug_assert_write_eq(offset, self.encoded_len_with_relay()); + + Ok(offset) + } + + /// Encodes the message into a [`Bytes`]. + fn encode_to_bytes(&self) -> Result { + let len = self.encoded_len(); + let mut buf = vec![0; len]; + self.encode(&mut buf).map(|_| Bytes::from(buf)) + } + + /// Encodes a relay message into a [`Bytes`]. + fn encode_relay_to_bytes(&self) -> Result { + let len = self.encoded_len_with_relay(); + let mut buf = vec![0; len]; + self.encode_relay(&mut buf).map(|_| Bytes::from(buf)) + } + + /// Returns the encoded length of the message. + fn encoded_len(&self) -> usize; + + /// Returns the encoded length of the message with a relay tag. + fn encoded_len_with_relay(&self) -> usize { + 1 + self.encoded_len() + } +} + +impl Encodable for &T { + fn encode(&self, buf: &mut [u8]) -> Result { + (*self).encode(buf) + } + + fn encoded_len(&self) -> usize { + (*self).encoded_len() + } +} + +macro_rules! impl_encodable { + ( + $( + $(#[$attr:meta])* + $type:ident $(<$($generic:ident), +$(,)?>)? = $id:expr, + )* + ) => { + $( + $(#[$attr])* + impl $(<$($generic), +>)? Encodable for $type $(<$($generic), +>)? + $( + where + $($generic: Data,)+ + )? 
+ { + fn encode(&self, buf: &mut [u8]) -> Result { + let mut offset = 0; + let buf_len = buf.len(); + bail!(self(offset, buf_len)); + + buf[offset] = $id; + offset += 1; + + offset += self.encode_length_delimited(&mut buf[offset..])?; + + #[cfg(debug_assertions)] + super::debug_assert_write_eq(offset, Encodable::encoded_len(self)); + + Ok(offset) + } + + fn encoded_len(&self) -> usize { + 1 + self.encoded_len_with_length_delimited() + } + } + )* + }; +} + +impl_encodable!( + LeaveMessage = LEAVE_MESSAGE_BYTE, + JoinMessage = JOIN_MESSAGE_BYTE, + // PushPullMessage = PUSH_PULL_MESSAGE_BYTE, + UserEventMessage = USER_EVENT_MESSAGE_BYTE, + QueryMessage = QUERY_MESSAGE_BYTE, + QueryResponseMessage = QUERY_RESPONSE_MESSAGE_BYTE, + ConflictResponseMessage = CONFLICT_RESPONSE_MESSAGE_BYTE, + #[cfg(feature = "encryption")] + KeyRequestMessage = KEY_REQUEST_MESSAGE_BYTE, + #[cfg(feature = "encryption")] + KeyResponseMessage = KEY_RESPONSE_MESSAGE_BYTE, +); + +impl super::Encodable for ConflictResponseMessageBorrow<'_, I, A> +where + I: Data, + A: Data, +{ + fn encode(&self, buf: &mut [u8]) -> Result { + let mut offset = 0; + let buf_len = buf.len(); + bail!(self(offset, buf_len)); + + buf[offset] = CONFLICT_RESPONSE_MESSAGE_BYTE; + offset += 1; + + offset += self.encode_in(&mut buf[offset..])?; + + #[cfg(debug_assertions)] + super::debug_assert_write_eq(offset, Encodable::encoded_len(self)); + + Ok(offset) + } + + fn encoded_len(&self) -> usize { + 1 + self.encoded_len_in() + } +} + +impl super::Encodable for PushPullMessageBorrow<'_, I> +where + I: Data, +{ + fn encode(&self, buf: &mut [u8]) -> Result { + let mut offset = 0; + let buf_len = buf.len(); + bail!(self(offset, buf_len)); + + buf[offset] = PUSH_PULL_MESSAGE_BYTE; + offset += 1; + + offset += self.encode_in(&mut buf[offset..])?; + + #[cfg(debug_assertions)] + super::debug_assert_write_eq(offset, Encodable::encoded_len(self)); + + Ok(offset) + } + + fn encoded_len(&self) -> usize { + 1 + self.encoded_len_in() + } 
+} diff --git a/serf-proto/src/push_pull.rs b/serf-proto/src/push_pull.rs index d676ebd..d748bfd 100644 --- a/serf-proto/src/push_pull.rs +++ b/serf-proto/src/push_pull.rs @@ -382,7 +382,7 @@ where macro_rules! bail { ($this:ident($offset:expr, $len:ident)) => { if $offset >= $len { - return Err(EncodeError::insufficient_buffer(self.encoded_len(), $len)); + return Err(EncodeError::insufficient_buffer($this.encoded_len(), $len)); } }; } @@ -455,3 +455,148 @@ where Ok(offset) } } + +/// Used when doing a state exchange. This +/// is a relatively large message, but is sent infrequently +#[viewit::viewit(getters(skip), setters(skip))] +#[derive(Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +pub struct PushPullMessageBorrow<'a, I> { + /// Current node lamport time + ltime: LamportTime, + /// Maps the node to its status time + status_ltimes: &'a IndexMap, + /// List of left nodes + left_members: &'a IndexSet, + /// Lamport time for event clock + event_ltime: LamportTime, + /// Recent events + events: &'a [Option], + /// Lamport time for query clock + query_ltime: LamportTime, +} + +impl Clone for PushPullMessageBorrow<'_, I> { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for PushPullMessageBorrow<'_, I> {} + +impl PushPullMessageBorrow<'_, I> +where + I: Data, +{ + pub(super) fn encoded_len_in(&self) -> usize { + let mut len = 0usize; + + len += 1 + self.ltime.encoded_len(); + + len += self + .status_ltimes + .iter() + .map(|(k, v)| 1 + TupleEncoder::new(k, v).encoded_len_with_length_delimited()) + .sum::(); + + len += self + .left_members + .iter() + .map(|id| 1 + id.encoded_len_with_length_delimited()) + .sum::(); + len += 1 + self.event_ltime.encoded_len(); + len += 1 + + self + .events + .iter() + .filter_map(|e| { + e.as_ref() + .map(|e| 1 + e.encoded_len_with_length_delimited()) + }) + .sum::(); + len += 1 + self.query_ltime.encoded_len(); + + len + } + + pub(super) fn encode_in(&self, buf: &mut [u8]) -> Result { + macro_rules! 
bail { + ($this:ident($offset:expr, $len:ident)) => { + if $offset >= $len { + return Err(EncodeError::insufficient_buffer( + $this.encoded_len_in(), + $len, + )); + } + }; + } + + let mut offset = 0; + let buf_len = buf.len(); + + bail!(self(offset, buf_len)); + buf[offset] = LTIME_BYTE; + offset += 1; + offset += self.ltime.encode(&mut buf[offset..])?; + + bail!(self(offset, buf_len)); + buf[offset] = STATUS_LTIMES_BYTE; + offset += 1; + + self + .status_ltimes + .iter() + .try_fold(&mut offset, |off, (k, v)| { + bail!(self(*off, buf_len)); + buf[*off] = LEFT_MEMBERS_BYTE; + *off += 1; + *off += TupleEncoder::new(k, v).encode_with_length_delimited(&mut buf[*off..])?; + Ok(off) + }) + .map_err(|e: EncodeError| e.update(self.encoded_len_in(), buf_len))?; + + self + .left_members + .iter() + .try_fold(&mut offset, |off, id| { + bail!(self(*off, buf_len)); + buf[*off] = LEFT_MEMBERS_BYTE; + *off += 1; + *off += id.encode_length_delimited(&mut buf[*off..])?; + Ok(off) + }) + .map_err(|e: EncodeError| e.update(self.encoded_len_in(), buf_len))?; + + bail!(self(offset, buf_len)); + buf[offset] = EVENT_LTIME_BYTE; + offset += 1; + offset += self.event_ltime.encode(&mut buf[offset..])?; + + bail!(self(offset, buf_len)); + buf[offset] = EVENTS_BYTE; + offset += 1; + + self + .events + .iter() + .filter_map(|e| e.as_ref()) + .try_fold(&mut offset, |off, e| { + bail!(self(*off, buf_len)); + buf[*off] = EVENTS_BYTE; + *off += 1; + *off += e.encode_length_delimited(&mut buf[*off..])?; + Ok(off) + }) + .map_err(|e: EncodeError| e.update(self.encoded_len_in(), buf_len))?; + + bail!(self(offset, buf_len)); + buf[offset] = QUERY_LTIME_BYTE; + offset += 1; + offset += self.query_ltime.encode(&mut buf[offset..])?; + + #[cfg(debug_assertions)] + super::debug_assert_write_eq(offset, self.encoded_len_in()); + + Ok(offset) + } +} diff --git a/serf/test/main.rs b/serf/test/main.rs index c8ed71e..bde3dd0 100644 --- a/serf/test/main.rs +++ b/serf/test/main.rs @@ -1,28 +1,28 @@ -use 
core::future::Future; -use serf_core::tests::run as run_unit_test; +// use core::future::Future; +// use serf_core::tests::run as run_unit_test; -#[cfg(feature = "net")] -#[path = "./main/net.rs"] -mod net; +// #[cfg(feature = "net")] +// #[path = "./main/net.rs"] +// mod net; -#[cfg(feature = "tokio")] -fn tokio_run(fut: impl Future) { - let runtime = ::tokio::runtime::Builder::new_multi_thread() - .worker_threads(32) - .enable_all() - .build() - .unwrap(); - run_unit_test(|fut| runtime.block_on(fut), fut) -} +// #[cfg(feature = "tokio")] +// fn tokio_run(fut: impl Future) { +// let runtime = ::tokio::runtime::Builder::new_multi_thread() +// .worker_threads(32) +// .enable_all() +// .build() +// .unwrap(); +// run_unit_test(|fut| runtime.block_on(fut), fut) +// } -#[cfg(feature = "smol")] -fn smol_run(fut: impl Future) { - use serf::agnostic::{RuntimeLite, smol::SmolRuntime}; - run_unit_test(SmolRuntime::block_on, fut); -} +// #[cfg(feature = "smol")] +// fn smol_run(fut: impl Future) { +// use serf::agnostic::{RuntimeLite, smol::SmolRuntime}; +// run_unit_test(SmolRuntime::block_on, fut); +// } -#[cfg(feature = "async-std")] -fn async_std_run(fut: impl Future) { - use serf::agnostic::{RuntimeLite, async_std::AsyncStdRuntime}; - run_unit_test(AsyncStdRuntime::block_on, fut); -} +// #[cfg(feature = "async-std")] +// fn async_std_run(fut: impl Future) { +// use serf::agnostic::{RuntimeLite, async_std::AsyncStdRuntime}; +// run_unit_test(AsyncStdRuntime::block_on, fut); +// } From 02ee79d3c12f3614949a29e68144bdce62e56f5a Mon Sep 17 00:00:00 2001 From: al8n Date: Fri, 28 Feb 2025 20:41:07 +0800 Subject: [PATCH 07/39] WIP --- serf-core/src/coalesce/member.rs | 18 +- serf-core/src/coalesce/user.rs | 2 +- serf-core/src/coordinate.rs | 43 +- serf-core/src/delegate/composite.rs | 13 +- serf-core/src/delegate/merge.rs | 8 +- serf-core/src/error.rs | 122 ++++- serf-core/src/event.rs | 7 +- serf-core/src/event/crate_event.rs | 48 +- serf-core/src/key_manager.rs | 79 +-- 
serf-core/src/serf.rs | 32 +- serf-core/src/serf/api.rs | 19 +- serf-core/src/serf/base.rs | 332 +++++++----- serf-core/src/serf/base/tests.rs | 10 +- serf-core/src/serf/base/tests/serf.rs | 4 +- serf-core/src/serf/base/tests/serf/event.rs | 6 +- serf-core/src/serf/base/tests/serf/join.rs | 10 +- serf-core/src/serf/base/tests/serf/leave.rs | 6 +- serf-core/src/serf/base/tests/serf/reap.rs | 4 +- .../src/serf/base/tests/serf/reconnect.rs | 4 +- .../src/serf/base/tests/serf/snapshot.rs | 10 +- serf-core/src/serf/delegate.rs | 504 +++++++++--------- serf-core/src/serf/internal_query.rs | 42 +- serf-core/src/serf/query.rs | 153 +++--- serf-core/src/snapshot.rs | 22 +- serf-proto/src/filter.rs | 125 +++-- serf-proto/src/filter/id_filter.rs | 63 --- serf-proto/src/key.rs | 18 +- serf-proto/src/message.rs | 402 +++++++++++++- serf-proto/src/query.rs | 18 +- serf-proto/src/query/response.rs | 12 +- serf-proto/src/tags.rs | 210 -------- serf-proto/src/user_event/message.rs | 16 +- 32 files changed, 1375 insertions(+), 987 deletions(-) delete mode 100644 serf-proto/src/filter/id_filter.rs diff --git a/serf-core/src/coalesce/member.rs b/serf-core/src/coalesce/member.rs index 4753326..6fe9446 100644 --- a/serf-core/src/coalesce/member.rs +++ b/serf-core/src/coalesce/member.rs @@ -4,7 +4,7 @@ use async_channel::Sender; use memberlist_core::{ CheapClone, proto::TinyVec, - transport::{AddressResolver, Node, Transport}, + transport::{Node, Transport}, }; use crate::{ @@ -22,12 +22,8 @@ pub(crate) struct CoalesceEvent { #[derive(Default)] pub(crate) struct MemberEventCoalescer { - last_events: - HashMap::ResolvedAddress>, MemberEventType>, - latest_events: HashMap< - Node::ResolvedAddress>, - CoalesceEvent::ResolvedAddress>, - >, + last_events: HashMap, MemberEventType>, + latest_events: HashMap, CoalesceEvent>, _m: PhantomData, } @@ -43,7 +39,7 @@ impl MemberEventCoalescer { impl Coalescer for MemberEventCoalescer where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, 
{ type Delegate = D; @@ -78,10 +74,8 @@ where &mut self, out_tx: &Sender>, ) -> Result<(), super::ClosedOutChannel> { - let mut events: HashMap< - MemberEventType, - MemberEventMut::ResolvedAddress>, - > = HashMap::with_capacity(self.latest_events.len()); + let mut events: HashMap> = + HashMap::with_capacity(self.latest_events.len()); // Coalesce the various events we got into a single set of events. for (id, cev) in self.latest_events.drain() { match self.last_events.get(&id) { diff --git a/serf-core/src/coalesce/user.rs b/serf-core/src/coalesce/user.rs index 8cf0ed6..7a31dc8 100644 --- a/serf-core/src/coalesce/user.rs +++ b/serf-core/src/coalesce/user.rs @@ -32,7 +32,7 @@ impl UserEventCoalescer { impl Coalescer for UserEventCoalescer where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { type Delegate = D; diff --git a/serf-core/src/coordinate.rs b/serf-core/src/coordinate.rs index 6a8a755..165bee8 100644 --- a/serf-core/src/coordinate.rs +++ b/serf-core/src/coordinate.rs @@ -4,7 +4,10 @@ use std::{ time::Duration, }; -use memberlist_core::CheapClone; +use memberlist_core::{ + CheapClone, + proto::{Data, DataRef, DecodeError, EncodeError, RepeatedDecoder}, +}; use parking_lot::RwLock; use rand::Rng; use smallvec::SmallVec; @@ -723,6 +726,44 @@ fn rand_f64() -> f64 { } } +/// The reference type to [`Coordinate`]. 
+#[derive(Copy, Clone, Debug, PartialEq)] +pub struct CoordinateRef<'a> { + portion: RepeatedDecoder<'a>, + error: f64, + adjustment: f64, + height: f64, +} + +impl<'a> DataRef<'a, Coordinate> for CoordinateRef<'a> { + fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> + where + Self: Sized, + { + todo!() + } +} + +impl Data for Coordinate { + type Ref<'a> = CoordinateRef<'a>; + + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + // Ok(val) + todo!() + } + + fn encoded_len(&self) -> usize { + todo!() + } + + fn encode(&self, buf: &mut [u8]) -> Result { + todo!() + } +} + #[cfg(test)] mod tests { use smol_str::SmolStr; diff --git a/serf-core/src/delegate/composite.rs b/serf-core/src/delegate/composite.rs index 40ae2e8..9d74bc0 100644 --- a/serf-core/src/delegate/composite.rs +++ b/serf-core/src/delegate/composite.rs @@ -1,16 +1,13 @@ -use memberlist_core::{ - CheapClone, - proto::TinyVec, - transport::{Id, Node}, -}; -use serf_proto::MessageType; +use memberlist_core::{CheapClone, transport::Id}; -use crate::{coordinate::Coordinate, types::Member}; +use crate::types::Member; use super::{ DefaultMergeDelegate, Delegate, MergeDelegate, NoopReconnectDelegate, ReconnectDelegate, }; +use std::sync::Arc; + /// `CompositeDelegate` is a helpful struct to split the [`Delegate`] into multiple small delegates, /// so that users do not need to implement full [`Delegate`] when they only want to custom some methods /// in the [`Delegate`]. 
@@ -78,7 +75,7 @@ where async fn notify_merge( &self, - members: TinyVec>, + members: Arc<[Member]>, ) -> Result<(), Self::Error> { self.merge.notify_merge(members).await } diff --git a/serf-core/src/delegate/merge.rs b/serf-core/src/delegate/merge.rs index d1728fb..7190ffa 100644 --- a/serf-core/src/delegate/merge.rs +++ b/serf-core/src/delegate/merge.rs @@ -1,5 +1,5 @@ -use memberlist_core::{CheapClone, proto::TinyVec, transport::Id}; -use std::future::Future; +use memberlist_core::{CheapClone, transport::Id}; +use std::{future::Future, sync::Arc}; use crate::types::Member; @@ -23,7 +23,7 @@ pub trait MergeDelegate: Send + Sync + 'static { /// the return value is `Err`, the merge is canceled. fn notify_merge( &self, - members: TinyVec>, + members: Arc<[Member]>, ) -> impl Future> + Send; } @@ -48,7 +48,7 @@ where async fn notify_merge( &self, - _members: TinyVec>, + _members: Arc<[Member]>, ) -> Result<(), Self::Error> { Ok(()) } diff --git a/serf-core/src/error.rs b/serf-core/src/error.rs index 5cc683e..8de090c 100644 --- a/serf-core/src/error.rs +++ b/serf-core/src/error.rs @@ -1,17 +1,66 @@ use std::sync::Arc; -use memberlist_core::{proto::TinyVec, transport::Transport}; +use memberlist_core::{ + delegate::DelegateError as MemberlistDelegateError, proto::TinyVec, transport::Transport, +}; use crate::{ - delegate::Delegate, + delegate::{Delegate, MergeDelegate}, serf::{SerfDelegate, SerfState}, types::Member, }; pub use crate::snapshot::SnapshotError; +/// Error trait for [`Delegate`] +#[derive(thiserror::Error)] +pub enum SerfDelegateError { + /// Serf error + #[error(transparent)] + Serf(#[from] SerfError), + /// [`MergeDelegate`] error + #[error(transparent)] + Merge(::Error), +} + +impl core::fmt::Debug for SerfDelegateError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Merge(err) => write!(f, "{err:?}"), + Self::Serf(err) => write!(f, "{err:?}"), + } + } +} + +impl SerfDelegateError { + /// Create a 
delegate error from a merge delegate error. + #[inline] + pub const fn merge(err: ::Error) -> Self { + Self::Merge(err) + } + + /// Create a delegate error from a serf error. + #[inline] + pub const fn serf(err: crate::error::SerfError) -> Self { + Self::Serf(err) + } +} + +impl From>> for SerfDelegateError +where + D: Delegate, + T: Transport, +{ + fn from(value: MemberlistDelegateError>) -> Self { + match value { + MemberlistDelegateError::AliveDelegate(e) => e, + MemberlistDelegateError::MergeDelegate(e) => e, + } + } +} + /// Error type for the serf crate. -#[derive(Debug, thiserror::Error)] +#[derive(thiserror::Error)] pub enum Error where D: Delegate, @@ -31,6 +80,21 @@ where Multiple(Arc<[Self]>), } +impl core::fmt::Debug for Error +where + D: Delegate, + T: Transport, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Memberlist(e) => write!(f, "{e:?}"), + Self::Serf(e) => write!(f, "{e:?}"), + Self::Relay(e) => write!(f, "{e:?}"), + Self::Multiple(e) => write!(f, "{e:?}"), + } + } +} + impl From for Error where D: Delegate, @@ -47,7 +111,17 @@ where T: Transport, { fn from(e: memberlist_core::proto::EncodeError) -> Self { - Self::Memberlist(e.into()) + Self::Serf(e.into()) + } +} + +impl From for Error +where + D: Delegate, + T: Transport, +{ + fn from(e: memberlist_core::proto::DecodeError) -> Self { + Self::Serf(e.into()) } } @@ -169,28 +243,28 @@ where #[derive(Debug, thiserror::Error)] pub enum SerfError { /// Returned when the user event exceeds the configured limit. - #[error("serf: user event exceeds configured limit of {0} bytes before encoding")] + #[error("user event exceeds configured limit of {0} bytes before encoding")] UserEventLimitTooLarge(usize), /// Returned when the user event exceeds the sane limit. 
- #[error("serf: user event exceeds sane limit of {0} bytes before encoding")] + #[error("user event exceeds sane limit of {0} bytes before encoding")] UserEventTooLarge(usize), /// Returned when the join status is bad. - #[error("serf: join called on {0} statues")] + #[error("join called on {0} statues")] BadJoinStatus(SerfState), /// Returned when the leave status is bad. - #[error("serf: leave called on {0} statues")] + #[error("leave called on {0} statues")] BadLeaveStatus(SerfState), /// Returned when the encoded user event exceeds the sane limit after encoding. - #[error("serf: user event exceeds sane limit of {0} bytes after encoding")] + #[error("user event exceeds sane limit of {0} bytes after encoding")] RawUserEventTooLarge(usize), /// Returned when the query size exceeds the configured limit. - #[error("serf: query exceeds limit of {0} bytes")] + #[error("query exceeds limit of {0} bytes")] QueryTooLarge(usize), /// Returned when the query is timeout. - #[error("serf: query response is past the deadline")] + #[error("query response is past the deadline")] QueryTimeout, /// Returned when the query response is too large. - #[error("serf: query response ({got} bytes) exceeds limit of {limit} bytes")] + #[error("query response ({got} bytes) exceeds limit of {limit} bytes")] QueryResponseTooLarge { /// The query response size limit. limit: usize, @@ -198,31 +272,37 @@ pub enum SerfError { got: usize, }, /// Returned when the query has already been responded. - #[error("serf: query response already sent")] + #[error("query response already sent")] QueryAlreadyResponsed, /// Returned when failed to truncate response so that it fits into message. - #[error("serf: failed to truncate response so that it fits into message")] + #[error("failed to truncate response so that it fits into message")] FailTruncateResponse, /// Returned when the tags too large. 
- #[error("serf: encoded length of tags exceeds limit of {0} bytes")] + #[error("encoded length of tags exceeds limit of {0} bytes")] TagsTooLarge(usize), /// Returned when the relayed response is too large. - #[error("serf: relayed response exceeds limit of {0} bytes")] + #[error("relayed response exceeds limit of {0} bytes")] RelayedResponseTooLarge(usize), /// Returned when failed to deliver query response, dropping. - #[error("serf: failed to deliver query response, dropping")] + #[error("failed to deliver query response, dropping")] QueryResponseDeliveryFailed, /// Returned when the coordinates are disabled. - #[error("serf: coordinates are disabled")] + #[error("coordinates are disabled")] CoordinatesDisabled, /// Returned when snapshot error. - #[error("serf: {0}")] + #[error(transparent)] Snapshot(#[from] SnapshotError), + /// Returned when trying to decode a serf data + #[error(transparent)] + Decode(#[from] memberlist_core::proto::DecodeError), + /// Returned when trying to encode a serf data + #[error(transparent)] + Encode(#[from] memberlist_core::proto::EncodeError), /// Returned when timed out broadcasting node removal. - #[error("serf: timed out broadcasting node removal")] + #[error("timed out broadcasting node removal")] RemovalBroadcastTimeout, /// Returned when the timed out broadcasting channel closed. 
- #[error("serf: timed out broadcasting channel closed")] + #[error("timed out broadcasting channel closed")] BroadcastChannelClosed, } diff --git a/serf-core/src/event.rs b/serf-core/src/event.rs index 0458f8b..962d250 100644 --- a/serf-core/src/event.rs +++ b/serf-core/src/event.rs @@ -12,12 +12,7 @@ pub use async_channel::{RecvError, TryRecvError}; use async_lock::Mutex; pub(crate) use crate_event::*; use futures::Stream; -use memberlist_core::{ - CheapClone, - bytes::Bytes, - proto::TinyVec, - transport::{AddressResolver, Transport}, -}; +use memberlist_core::{CheapClone, bytes::Bytes, proto::TinyVec, transport::Transport}; use serf_proto::{LamportTime, Member, Node, QueryFlag, QueryResponseMessage, UserEventMessage}; use smol_str::SmolStr; diff --git a/serf-core/src/event/crate_event.rs b/serf-core/src/event/crate_event.rs index 1423abf..bf9444f 100644 --- a/serf-core/src/event/crate_event.rs +++ b/serf-core/src/event/crate_event.rs @@ -1,5 +1,5 @@ use memberlist_core::proto::{Data, DecodeError}; -use serf_proto::QueryMessage; +use serf_proto::{QueryMessage, QueryMessageRef}; use super::*; @@ -32,6 +32,31 @@ where } } +impl<'a, I, A> QueryMessageExt for QueryMessageRef<'a, I::Ref<'a>, A> +where + I: Data, +{ + fn decode_internal_query(&self) -> Option, DecodeError>> { + Some(Ok(match self.name() { + INTERNAL_PING => InternalQueryEvent::Ping, + INTERNAL_CONFLICT => { + return Some( + ::decode(self.payload()).map(|(_, id)| InternalQueryEvent::Conflict(id)), + ); + } + #[cfg(feature = "encryption")] + INTERNAL_INSTALL_KEY => InternalQueryEvent::InstallKey, + #[cfg(feature = "encryption")] + INTERNAL_USE_KEY => InternalQueryEvent::UseKey, + #[cfg(feature = "encryption")] + INTERNAL_REMOVE_KEY => InternalQueryEvent::RemoveKey, + #[cfg(feature = "encryption")] + INTERNAL_LIST_KEYS => InternalQueryEvent::ListKey, + _ => return None, + })) + } +} + const INTERNAL_PING: &str = "_serf_ping"; const INTERNAL_CONFLICT: &str = "_serf_conflict"; #[cfg(feature = 
"encryption")] @@ -56,10 +81,10 @@ pub enum CrateEventType { pub(crate) enum CrateEvent where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { - Member(MemberEvent::ResolvedAddress>), + Member(MemberEvent), User(UserEventMessage), Query(QueryEvent), InternalQuery { @@ -70,7 +95,7 @@ where impl Clone for CrateEvent where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { fn clone(&self) -> Self { @@ -88,7 +113,7 @@ where impl CrateEvent where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { /// Returns the type of the event @@ -108,20 +133,19 @@ where } } -impl From::ResolvedAddress>> - for CrateEvent +impl From> for CrateEvent where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { - fn from(value: MemberEvent::ResolvedAddress>) -> Self { + fn from(value: MemberEvent) -> Self { Self::Member(value) } } impl From for CrateEvent where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { fn from(value: UserEventMessage) -> Self { @@ -131,7 +155,7 @@ where impl From> for CrateEvent where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { fn from(value: QueryEvent) -> Self { @@ -141,7 +165,7 @@ where impl From<(InternalQueryEvent, QueryEvent)> for CrateEvent where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { fn from(value: (InternalQueryEvent, QueryEvent)) -> Self { diff --git a/serf-core/src/key_manager.rs b/serf-core/src/key_manager.rs index a67f139..4c2135d 100644 --- a/serf-core/src/key_manager.rs +++ b/serf-core/src/key_manager.rs @@ -3,14 +3,9 @@ use std::{collections::HashMap, sync::OnceLock}; use async_channel::Receiver; use async_lock::RwLock; use futures::StreamExt; -use memberlist_core::{ - CheapClone, - bytes::{BufMut, BytesMut}, - proto::SecretKey, - tracing, - transport::{AddressResolver, Transport}, -}; -use smol_str::SmolStr; +use memberlist_core::{CheapClone, proto::SecretKey, tracing, transport::Transport}; +use serf_proto::MessageRef; +use 
smol_str::{SmolStr, format_smolstr}; use crate::event::{ INTERNAL_INSTALL_KEY, INTERNAL_LIST_KEYS, INTERNAL_REMOVE_KEY, INTERNAL_USE_KEY, @@ -84,7 +79,7 @@ pub struct KeyRequestOptions { /// encryption keyring changes across a cluster. pub struct KeyManager where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { serf: OnceLock>, @@ -94,7 +89,7 @@ where impl KeyManager where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { pub(crate) fn new() -> Self { @@ -193,7 +188,7 @@ where if let Some(opts) = opts { q_param.relay_factor = opts.relay_factor; } - let qresp: QueryResponse::ResolvedAddress> = serf + let qresp: QueryResponse = serf .internal_query(SmolStr::new(ty), buf, Some(q_param), event) .await?; @@ -222,7 +217,7 @@ where async fn stream_key_response( &self, - ch: Receiver::ResolvedAddress>>, + ch: Receiver>, ) -> KeyResponse { let mut resp = KeyResponse { num_nodes: self.serf.get().unwrap().num_members().await, @@ -253,16 +248,14 @@ where continue; } - let node_response = match decode_message(MessageType::KeyResponse, &r.payload[1..]) { - Ok((_, nr)) => match nr { - SerfMessage::KeyResponse(kr) => kr, + let node_response = match serf_proto::decode_message::(&r.payload) + { + Ok(msg) => match msg { + MessageRef::KeyResponse(kr) => kr, msg => { resp.messages.insert( r.from.id().cheap_clone(), - SmolStr::new(format!( - "Invalid key query response type: {:?}", - msg.ty().as_str() - )), + format_smolstr!("Invalid key query response type: {}", msg.ty()), ); resp.num_err += 1; @@ -286,27 +279,47 @@ where } }; - if !node_response.result { - resp - .messages - .insert(r.from.id().cheap_clone(), node_response.message); + if !node_response.result() { + resp.messages.insert( + r.from.id().cheap_clone(), + SmolStr::new(node_response.message()), + ); resp.num_err += 1; - } else if node_response.result && node_response.message.is_empty() { - tracing::warn!("serf: {}", node_response.message); - resp - .messages - .insert(r.from.id().cheap_clone(), 
node_response.message); + } else if node_response.result() && node_response.message().is_empty() { + tracing::warn!("serf: {}", node_response.message()); + resp.messages.insert( + r.from.id().cheap_clone(), + SmolStr::new(node_response.message()), + ); } // Currently only used for key list queries, this adds keys to a counter // and increments them for each node response which contains them. - for k in node_response.keys { - let count = resp.keys.entry(k).or_insert(0); - *count += 1; + let res = node_response + .keys() + .iter::() + .try_for_each(|res| { + res.map(|k| { + let count = resp.keys.entry(k).or_insert(0); + *count += 1; + }) + }); + + if let Err(e) = res { + resp.messages.insert( + r.from.id().cheap_clone(), + SmolStr::new(format!("Failed to decode key query response: {:?}", e)), + ); + resp.num_err += 1; + + if resp.num_resp == resp.num_nodes { + return resp; + } + continue; } - if let Some(pk) = node_response.primary_key { - let ctr = resp.primary_keys.entry(pk).or_insert(0); + if let Some(pk) = node_response.primary_key() { + let ctr = resp.primary_keys.entry(*pk).or_insert(0); *ctr += 1; } diff --git a/serf-core/src/serf.rs b/serf-core/src/serf.rs index d63e95d..ee72a68 100644 --- a/serf-core/src/serf.rs +++ b/serf-core/src/serf.rs @@ -130,40 +130,24 @@ where pub(crate) struct SerfCore> where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { pub(crate) clock: LamportClock, pub(crate) event_clock: LamportClock, pub(crate) query_clock: LamportClock, - broadcasts: Arc< - TransmitLimitedQueue< - SerfBroadcast, - NumMembers::ResolvedAddress>, - >, - >, - event_broadcasts: Arc< - TransmitLimitedQueue< - SerfBroadcast, - NumMembers::ResolvedAddress>, - >, - >, - query_broadcasts: Arc< - TransmitLimitedQueue< - SerfBroadcast, - NumMembers::ResolvedAddress>, - >, - >, + broadcasts: Arc>>, + event_broadcasts: Arc>>, + query_broadcasts: Arc>>, pub(crate) memberlist: Memberlist>, - pub(crate) members: - Arc::ResolvedAddress>>>, + pub(crate) 
members: Arc>>, event_tx: async_channel::Sender>, pub(crate) event_join_ignore: AtomicBool, pub(crate) event_core: RwLock, - query_core: Arc::ResolvedAddress>>>, + query_core: Arc>>, handles: AtomicRefCell< FuturesUnordered<<::Spawner as AsyncSpawner>::JoinHandle<()>>, >, @@ -190,7 +174,7 @@ where #[repr(transparent)] pub struct Serf> where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { pub(crate) inner: Arc>, @@ -198,7 +182,7 @@ where impl Clone for Serf where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { fn clone(&self) -> Self { diff --git a/serf-core/src/serf/api.rs b/serf-core/src/serf/api.rs index efcdf19..f61c7e7 100644 --- a/serf-core/src/serf/api.rs +++ b/serf-core/src/serf/api.rs @@ -58,7 +58,7 @@ where impl Serf where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { /// Creates a new Serf instance with the given transport and options. @@ -104,7 +104,7 @@ where /// Returns the local node's ID and the advertised address #[inline] - pub fn advertise_node(&self) -> Node::ResolvedAddress> { + pub fn advertise_node(&self) -> Node { self.inner.memberlist.advertise_node() } @@ -134,9 +134,7 @@ where /// Returns a point-in-time snapshot of the members of this cluster. 
#[inline] - pub async fn members( - &self, - ) -> OneOrMore::ResolvedAddress>> { + pub async fn members(&self) -> OneOrMore> { self .inner .members @@ -202,9 +200,7 @@ where /// Returns the Member information for the local node #[inline] - pub async fn local_member( - &self, - ) -> Member::ResolvedAddress> { + pub async fn local_member(&self) -> Member { self .inner .members @@ -290,7 +286,7 @@ where self.inner.event_clock.increment(); // Process update locally - self.handle_user_event(msg).await; + self.handle_user_event(either::Either::Right(msg)).await; self .inner @@ -312,8 +308,7 @@ where name: impl Into, payload: impl Into, params: Option>, - ) -> Result::ResolvedAddress>, Error> - { + ) -> Result, Error> { self .query_in(name.into(), payload.into(), params, None) .await @@ -326,7 +321,7 @@ where &self, node: Node>, ignore_old: bool, - ) -> Result::ResolvedAddress>, Error> { + ) -> Result, Error> { // Do a quick state check let current_state = self.state(); if current_state != SerfState::Alive { diff --git a/serf-core/src/serf/base.rs b/serf-core/src/serf/base.rs index c17aac4..b1870af 100644 --- a/serf-core/src/serf/base.rs +++ b/serf-core/src/serf/base.rs @@ -1,15 +1,18 @@ use std::time::Duration; +use either::Either; use futures::{FutureExt, StreamExt}; use memberlist_core::{ CheapClone, - bytes::{BufMut, Bytes, BytesMut}, + agnostic_lite::AfterHandle, + bytes::Bytes, delegate::EventDelegate, - proto::{Meta, NodeState, OneOrMore, TinyVec}, + proto::{Data, Meta, NodeState, OneOrMore, TinyVec}, tracing, transport::{MaybeResolvedAddress, Node}, }; use rand::{Rng, SeedableRng}; +use serf_proto::{MessageRef, QueryMessageRef, QueryResponseMessageRef, Tags, UserEventMessageRef}; use smol_str::SmolStr; use crate::{ @@ -37,7 +40,7 @@ use super::*; impl Serf where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { #[cfg(feature = "test")] @@ -72,8 +75,9 @@ where { let tags = opts.tags.load(); if !tags.as_ref().is_empty() { - let len = 
tags_encoded_len(&tags); - if len > Meta::MAX_SIZE { + let len = tags.encoded_len_with_length_delimited(); + let meta_encoded_len = 1 + (len as u32).encoded_len() + len; + if meta_encoded_len > Meta::MAX_SIZE { return Err(Error::tags_too_large(len)); } } @@ -126,7 +130,7 @@ where // Try access the snapshot let (old_clock, old_event_clock, old_query_clock, event_tx, alive_nodes, handle) = if let Some(sp) = opts.snapshot_path.as_ref() { - let rs = open_and_replay_snapshot::<_, _, D>(sp, opts.rejoin_after_leave)?; + let rs = open_and_replay_snapshot(sp, opts.rejoin_after_leave)?; let old_clock = rs.last_clock; let old_event_clock = rs.last_event_clock; let old_query_clock = rs.last_query_clock; @@ -383,7 +387,7 @@ where // Process update locally self.handle_node_join_intent(&msg).await; - let msg = SerfMessage::Join(msg); + let msg = serf_proto::Encodable::encode_to_bytes(&msg)?; // Start broadcasting the update if let Err(e) = self.broadcast(msg, None).await { tracing::warn!(err=%e, "serf: failed to broadcast join intent"); @@ -468,7 +472,7 @@ where return Ok(()); } - let msg = SerfMessage::Leave(msg); + let msg = serf_proto::Encodable::encode_to_bytes(&msg)?; // Broadcast the remove let (ntx, nrx) = async_channel::bounded(1); self.broadcast(msg, Some(ntx)).await?; @@ -483,12 +487,12 @@ where struct Reaper where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { coord_core: Option>>, memberlist: Memberlist>, - members: Arc::ResolvedAddress>>>, + members: Arc>>, event_tx: async_channel::Sender>, shutdown_rx: async_channel::Receiver<()>, reap_interval: Duration, @@ -555,7 +559,7 @@ macro_rules! 
reap { impl Reaper where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { async fn run(self) { @@ -589,7 +593,7 @@ where async fn reap_failed( local_id: &T::Id, - old: &mut Members::ResolvedAddress>, + old: &mut Members, event_tx: &async_channel::Sender>, reconnector: Option<&D>, coord: Option<&CoordCore>, @@ -600,7 +604,7 @@ where async fn reap_left( local_id: &T::Id, - old: &mut Members::ResolvedAddress>, + old: &mut Members, event_tx: &async_channel::Sender>, reconnector: Option<&D>, coord: Option<&CoordCore>, @@ -613,9 +617,9 @@ where struct Reconnector where T: Transport, - D: Delegate::ResolvedAddress>, + D: Delegate, { - members: Arc::ResolvedAddress>>>, + members: Arc>>, memberlist: Memberlist>, shutdown_rx: async_channel::Receiver<()>, reconnect_interval: Duration, @@ -623,11 +627,11 @@ where impl Reconnector where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { fn spawn(self) -> <::Spawner as AsyncSpawner>::JoinHandle<()> { - let mut rng = rand::rngs::StdRng::from_rng(rand::thread_rng()).unwrap(); + let mut rng = rand::rngs::StdRng::from_rng(&mut rand::rng()); ::spawn(async move { let tick = ::interval(self.reconnect_interval); @@ -657,7 +661,7 @@ where } // Select a random member to try and join - let idx: usize = rng.gen_range(0..num_failed); + let idx: usize = rng.random_range(0..num_failed); let member = &mu.failed_members[idx]; let (id, address) = member.member.node().cheap_clone().into_components(); @@ -743,38 +747,58 @@ where // ---------------------------------Hanlders Methods------------------------------- impl Serf where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { /// Called when a user event broadcast is /// received. Returns if the message should be rebroadcast. 
- pub(crate) async fn handle_user_event(&self, msg: UserEventMessage) -> bool { + pub(crate) async fn handle_user_event( + &self, + msg: Either, UserEventMessage>, + ) -> bool { + let (ltime, name) = match &msg { + Either::Left(msg) => (msg.ltime(), msg.name()), + Either::Right(msg) => (msg.ltime, msg.name.as_str()), + }; + // Witness a potentially newer time - self.inner.event_clock.witness(msg.ltime); + self.inner.event_clock.witness(ltime); let mut el = self.inner.event_core.write().await; // Ignore if it is before our minimum event time - if msg.ltime < el.min_time { + if ltime < el.min_time { return false; } // Check if this message is too old let bltime = LamportTime::new(el.buffer.len() as u64); let cur_time = self.inner.event_clock.time(); - if cur_time > bltime && msg.ltime < cur_time - bltime { + if cur_time > bltime && ltime < cur_time - bltime { tracing::warn!( "serf: received old event {} from time {} (current: {})", - msg.name, - msg.ltime, + name, + ltime, cur_time ); return false; } // Check if we've already seen this - let idx = u64::from(msg.ltime % bltime) as usize; + let idx = u64::from(ltime % bltime) as usize; let seen: Option<&mut UserEvents> = el.buffer[idx].as_mut(); + + let msg = match msg { + Either::Left(msg) => match UserEventMessage::from_ref(msg) { + Ok(msg) => msg, + Err(e) => { + tracing::warn!("serf: failed to decode user event message: {}", e); + return false; + } + }, + Either::Right(msg) => msg, + }; + let user_event = UserEvent { name: msg.name.clone(), payload: msg.payload.clone(), @@ -788,7 +812,7 @@ where seen.events.push(user_event); } else { el.buffer[idx] = Some(UserEvents { - ltime: msg.ltime, + ltime, events: OneOrMore::from(user_event), }); } @@ -819,20 +843,26 @@ where pub(crate) fn query_event( &self, - q: QueryMessage::ResolvedAddress>, + ltime: LamportTime, + name: SmolStr, + payload: Bytes, + timeout: Duration, + id: u32, + from: Node, + relay_factor: u8, ) -> QueryEvent { QueryEvent { - ltime: q.ltime, - name: 
q.name, - payload: q.payload, + ltime, + name, + payload, ctx: Arc::new(QueryContext { - query_timeout: q.timeout, + query_timeout: timeout, span: Mutex::new(Some(Epoch::now())), this: self.clone(), }), - id: q.id, - from: q.from, - relay_factor: q.relay_factor, + id, + from, + relay_factor, } } @@ -842,8 +872,7 @@ where payload: Bytes, params: Option>, ty: InternalQueryEvent, - ) -> Result::ResolvedAddress>, Error> - { + ) -> Result, Error> { self.query_in(name, payload, params, Some(ty)).await } @@ -853,8 +882,7 @@ where payload: Bytes, params: Option>, ty: Option>, - ) -> Result::ResolvedAddress>, Error> - { + ) -> Result, Error> { // Provide default parameters if none given. let params = match params { Some(params) if params.timeout != Duration::ZERO => params, @@ -868,9 +896,6 @@ where // Get the local node let local = self.inner.memberlist.advertise_node(); - // Encode the filters - let filters = params.encode_filters::()?; - // Setup the flags let flags = if params.request_ack { QueryFlag::ACK @@ -883,7 +908,7 @@ where ltime: self.inner.query_clock.time(), id: rand::random(), from: local.cheap_clone(), - filters, + filters: params.filters, flags, relay_factor: params.relay_factor, timeout: params.timeout, @@ -908,7 +933,7 @@ where .await; // Process query locally - self.handle_query(q, ty).await; + self.handle_query(Either::Right(q), ty).await; // Start broadcasting the event self @@ -927,7 +952,7 @@ where pub(crate) async fn register_query_response( &self, timeout: Duration, - resp: QueryResponse::ResolvedAddress>, + resp: QueryResponse, ) { let tresps = self.inner.query_core.clone(); let mut resps = self.inner.query_core.write().await; @@ -950,17 +975,41 @@ where /// received. Returns if the message should be rebroadcast. 
pub(crate) async fn handle_query( &self, - q: QueryMessage::ResolvedAddress>, + q: Either< + as Data>::Ref<'_>, + QueryMessage, + >, ty: Option>, - ) -> bool { + ) -> Result { + let (qm_ltime, qm_id, qm_timeout, no_broadcast, ack, name, filters) = match q.as_ref() { + Either::Left(q) => ( + q.ltime(), + q.id(), + q.timeout(), + q.no_broadcast(), + q.ack(), + q.name(), + Either::Left(*q.filters()), + ), + Either::Right(q) => ( + q.ltime, + q.id(), + q.timeout(), + q.no_broadcast(), + q.ack(), + q.name.as_str(), + Either::Right(q.filters.as_slice()), + ), + }; + // Witness a potentially newer time - self.inner.query_clock.witness(q.ltime); + self.inner.query_clock.witness(qm_ltime); let mut query = self.inner.query_core.write().await; // Ignore if it is before our minimum query time - if q.ltime < query.min_time { - return false; + if qm_ltime < query.min_time { + return Ok(false); } // Check if this message is too old @@ -969,30 +1018,30 @@ where if cur_time > q_time && q_time < cur_time - q_time { tracing::warn!( "serf: received old query {} from time {} (current: {})", - q.name, - q.ltime, + name, + qm_ltime, cur_time ); - return false; + return Ok(false); } // Check if we've already seen this - let idx = u64::from(q.ltime % q_time) as usize; + let idx = u64::from(qm_ltime % q_time) as usize; let seen = query.buffer[idx].as_mut(); if let Some(seen) = seen { - if seen.ltime == q.ltime { + if seen.ltime == qm_ltime { for &prev in seen.query_ids.iter() { - if q.id == prev { + if qm_id == prev { // Seen this ID already - return false; + return Ok(false); } } } - seen.query_ids.push(q.id); + seen.query_ids.push(qm_id); } else { query.buffer[idx] = Some(Queries { - ltime: q.ltime, - query_ids: MediumVec::from(q.id), + ltime: qm_ltime, + query_ids: MediumVec::from(qm_id), }); } @@ -1006,7 +1055,7 @@ where .increment(1); // TODO: how to avoid allocating here? 
- let named = format!("serf.queries.{}", q.name); + let named = format!("serf.queries.{}", name); metrics::counter!( named, self.inner.opts.memberlist_options.metric_labels().iter() @@ -1016,22 +1065,22 @@ where // Check if we should rebroadcast, this may be disabled by a flag let mut rebroadcast = true; - if q.no_broadcast() { + if no_broadcast { rebroadcast = false; } // Filter the query - if !self.should_process_query(&q.filters) { + if !self.should_process_query(filters)? { // Even if we don't process it further, we should rebroadcast, // since it is the first time we've seen this. - return rebroadcast; + return Ok(rebroadcast); } // Send ack if requested, without waiting for client to respond() - if q.ack() { + let (name, payload, from, relay_factor) = if ack { let ack = QueryResponseMessage { - ltime: q.ltime, - id: q.id, + ltime: qm_ltime, + id: qm_id, from: self.inner.memberlist.advertise_node(), flags: QueryFlag::ACK, payload: Bytes::new(), @@ -1039,24 +1088,59 @@ where match serf_proto::Encodable::encode_to_bytes(&ack) { Ok(raw) => { - if let Err(e) = self.inner.memberlist.send(q.from().address(), raw).await { + let (name, payload, from, relay_factor) = match q { + Either::Left(q) => ( + SmolStr::new(q.name()), + Bytes::copy_from_slice(q.payload()), + Node::from_ref(*q.from())?, + q.relay_factor(), + ), + Either::Right(q) => (q.name, q.payload, q.from, q.relay_factor), + }; + + if let Err(e) = self.inner.memberlist.send(from.address(), raw).await { tracing::error!(err=%e, "serf: failed to send ack"); } - if let Err(e) = self - .relay_response(q.relay_factor, q.from.clone(), ack) - .await - { + if let Err(e) = self.relay_response(relay_factor, from.clone(), ack).await { tracing::error!(err=%e, "serf: failed to relay ack"); } + (name, payload, from, relay_factor) } Err(e) => { tracing::error!(err=%e, "serf: failed to format ack"); + match q { + Either::Left(q) => ( + SmolStr::new(q.name()), + Bytes::copy_from_slice(q.payload()), + 
Node::from_ref(*q.from())?, + q.relay_factor(), + ), + Either::Right(q) => (q.name, q.payload, q.from, q.relay_factor), + } } } - } + } else { + match q { + Either::Left(q) => ( + SmolStr::new(q.name()), + Bytes::copy_from_slice(q.payload()), + Node::from_ref(*q.from())?, + q.relay_factor(), + ), + Either::Right(q) => (q.name, q.payload, q.from, q.relay_factor), + } + }; - let ev = self.query_event(q); + let ev = self.query_event( + qm_ltime, + name, + payload, + qm_timeout, + qm_id, + from, + relay_factor, + ); if let Err(e) = self .inner @@ -1070,15 +1154,15 @@ where tracing::error!(err=%e, "serf: failed to send query"); } - rebroadcast + Ok(rebroadcast) } /// Called when a query response is /// received. pub(crate) async fn handle_query_response( &self, - resp: QueryResponseMessage::ResolvedAddress>, - ) { + resp: as Data>::Ref<'_>, + ) -> Result<(), memberlist_core::proto::DecodeError> { // Look for a corresponding QueryResponse let qc = self .inner @@ -1086,19 +1170,21 @@ where .read() .await .responses - .get(&resp.ltime) + .get(&resp.ltime()) .cloned(); if let Some(query) = qc { // Verify the ID matches - if query.id != resp.id { + if query.id != resp.id() { tracing::warn!( "serf: query reply ID mismatch (local: {}, response: {})", query.id, - resp.id + resp.id() ); - return; + return Ok(()); } + let resp = QueryResponseMessage::::from_ref(resp)?; + query .handle_query_response::( resp, @@ -1109,20 +1195,19 @@ where .await; } else { tracing::warn!( - "serf: reply for non-running query (LTime: {}, ID: {}) From: {}", - resp.ltime, - resp.id, - resp.from + "serf: reply for non-running query (LTime: {}, ID: {}) From: {:?}", + resp.ltime(), + resp.id(), + resp.from() ); } + + Ok(()) } /// Called when a node join event is received /// from memberlist. 
- pub(crate) async fn handle_node_join( - &self, - n: Arc::ResolvedAddress>>, - ) { + pub(crate) async fn handle_node_join(&self, n: Arc>) { let mut members = self.inner.members.write().await; #[cfg(any(test, feature = "test"))] @@ -1136,7 +1221,7 @@ where let node = n.node(); let tags = if !n.meta().is_empty() { - match decode_tags(n.meta()) { + match ::decode(n.meta()) { Ok((readed, tags)) => { tracing::trace!(read = %readed, tags=?tags, "serf: decode tags successfully"); tags @@ -1291,10 +1376,7 @@ where } } - pub(crate) async fn handle_node_leave( - &self, - n: Arc::ResolvedAddress>>, - ) { + pub(crate) async fn handle_node_leave(&self, n: Arc>) { let mut members = self.inner.members.write().await; let Some(member_state) = members.states.get_mut(n.id()) else { @@ -1495,11 +1577,8 @@ where /// Called when a node meta data update /// has taken place - pub(crate) async fn handle_node_update( - &self, - n: Arc::ResolvedAddress>>, - ) { - let tags = match decode_tags(n.meta()) { + pub(crate) async fn handle_node_update(&self, n: Arc>) { + let tags = match ::decode(n.meta()) { Ok((readed, tags)) => { tracing::trace!(read = %readed, tags=?tags, "serf: decode tags successfully"); tags @@ -1552,8 +1631,8 @@ where /// erases a member from the list of members pub(crate) async fn handle_prune( &self, - member: &MemberState::ResolvedAddress>, - members: &mut Members::ResolvedAddress>, + member: &MemberState, + members: &mut Members, ) { let ms = member.member.status; if ms == MemberStatus::Leaving { @@ -1582,8 +1661,8 @@ where /// will reject the "new" node mapping, but we can still be notified. pub(crate) async fn handle_node_conflict( &self, - existing: Arc::ResolvedAddress>>, - other: Arc::ResolvedAddress>>, + existing: Arc>, + other: Arc>, ) { // Log a basic warning if the node is not us... 
if existing.id() != self.inner.memberlist.local_id() { @@ -1617,13 +1696,14 @@ where // Get the local node let local_id = self.inner.memberlist.local_id(); let local_advertise_addr = self.inner.memberlist.advertise_address(); - let encoded_id_len = id_encoded_len(local_id); - let mut payload = vec![0u8; encoded_id_len]; - if let Err(e) = encode_id(local_id, &mut payload) { - tracing::error!(err=%e, "serf: failed to encode local id"); - return; - } + let payload = match local_id.encode_to_bytes() { + Ok(id) => id, + Err(e) => { + tracing::error!(err=%e, "serf: failed to encode local id"); + return; + } + }; // Start an id resolution query let ty = InternalQueryEvent::Conflict(local_id.clone()); @@ -1645,29 +1725,29 @@ where // Gather responses let resp_rx = resp.response_rx(); while let Ok(r) = resp_rx.recv().await { - // Decode the response - if r.payload.is_empty() || r.payload[0] != MessageType::ConflictResponse as u8 { - tracing::warn!( - "serf: invalid conflict query response type: {:?}", - r.payload.as_ref() - ); - continue; - } - - match decode_message(MessageType::ConflictResponse, &r.payload[1..]) { - Ok((_, decoded)) => { - match decoded { - SerfMessage::ConflictResponse(member) => { + let res = serf_proto::decode_message::(&r.payload); + match res { + Ok(msg) => { + match msg { + MessageRef::ConflictResponse(resp) => { // Update the counters responses += 1; - if member.node.address().eq(local_advertise_addr) { - matching += 1; + match ::from_ref(*resp.member().node().address()) { + Ok(addr) => { + if addr.eq(local_advertise_addr) { + matching += 1; + } + } + Err(e) => { + tracing::error!(err=%e, "serf: failed to decode conflict query response"); + continue; + } } } msg => { tracing::warn!( - "serf: invalid conflict query response type: {}", - msg.ty().as_str() + type = %msg.ty(), + "serf: invalid conflict query response type", ); continue; } diff --git a/serf-core/src/serf/base/tests.rs b/serf-core/src/serf/base/tests.rs index 8073112..024efe4 100644 --- 
a/serf-core/src/serf/base/tests.rs +++ b/serf-core/src/serf/base/tests.rs @@ -42,7 +42,7 @@ fn test_config() -> Options { async fn wait_until_num_nodes(desired_nodes: usize, serfs: &[Serf]) where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { let start = Epoch::now(); @@ -68,7 +68,7 @@ where async fn wait_until_intent_queue_len(desired_len: usize, serfs: &[Serf]) where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { let start = Epoch::now(); @@ -104,7 +104,7 @@ async fn test_events( node: T::Id, expected: Vec, ) where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { let mut actual = Vec::with_capacity(expected.len()); @@ -147,7 +147,7 @@ async fn test_user_events( expected_name: Vec, expected_payload: Vec, ) where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { let mut actual_name = Vec::with_capacity(expected_name.len()); @@ -182,7 +182,7 @@ async fn test_query_events( expected_name: Vec, expected_payload: Vec, ) where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { let mut actual_name = Vec::with_capacity(expected_name.len()); diff --git a/serf-core/src/serf/base/tests/serf.rs b/serf-core/src/serf/base/tests/serf.rs index 08209e6..6f5366d 100644 --- a/serf-core/src/serf/base/tests/serf.rs +++ b/serf-core/src/serf/base/tests/serf.rs @@ -54,7 +54,7 @@ fn test_member_status( /// Unit tests for the get queue max pub async fn serf_get_queue_max( transport_opts: T::Options, - mut get_addr: impl FnMut(usize) -> ::ResolvedAddress, + mut get_addr: impl FnMut(usize) -> T::ResolvedAddress, ) where T: Transport, T::Options: Clone, @@ -161,7 +161,7 @@ pub async fn serf_get_queue_max( pub async fn serf_update( transport_opts1: T::Options, transport_opts2: T::Options, - get_transport: impl FnOnce(T::Id, ::ResolvedAddress) -> F + Copy, + get_transport: impl FnOnce(T::Id, T::ResolvedAddress) -> F + Copy, ) where T: Transport, T::Options: Clone, diff --git a/serf-core/src/serf/base/tests/serf/event.rs 
b/serf-core/src/serf/base/tests/serf/event.rs index c4dae30..d4e6ce9 100644 --- a/serf-core/src/serf/base/tests/serf/event.rs +++ b/serf-core/src/serf/base/tests/serf/event.rs @@ -248,7 +248,7 @@ pub async fn serf_events_leave_avoid_infinite_rebroadcast( transport_opts2: T::Options, transport_opts3: T::Options, transport_opts4: T::Options, - get_transport_opts: impl FnOnce(T::Id, ::ResolvedAddress) -> F + Copy, + get_transport_opts: impl FnOnce(T::Id, T::ResolvedAddress) -> F + Copy, ) where T: Transport, F: core::future::Future, @@ -652,7 +652,7 @@ where /// Unit tests for the query old message pub async fn query_old_message( transport_opts: T::Options, - from: Node::ResolvedAddress>, + from: Node, ) where T: Transport, { @@ -689,7 +689,7 @@ pub async fn query_old_message( /// Unit tests for the query same clock pub async fn query_same_clock( transport_opts: T::Options, - from: Node::ResolvedAddress>, + from: Node, ) where T: Transport, { diff --git a/serf-core/src/serf/base/tests/serf/join.rs b/serf-core/src/serf/base/tests/serf/join.rs index 08916ce..1aca9ae 100644 --- a/serf-core/src/serf/base/tests/serf/join.rs +++ b/serf-core/src/serf/base/tests/serf/join.rs @@ -37,7 +37,7 @@ where /// Unit tests for the join intent old message pub async fn join_intent_old_message( transport_opts: T::Options, - addr: ::ResolvedAddress, + addr: T::ResolvedAddress, ) where T: Transport, { @@ -89,7 +89,7 @@ pub async fn join_intent_old_message( /// Unit tests for the join intent newer pub async fn join_intent_newer( transport_opts: T::Options, - addr: ::ResolvedAddress, + addr: T::ResolvedAddress, ) where T: Transport, { @@ -140,7 +140,7 @@ pub async fn join_intent_newer( /// Unit tests for the join intent reset leaving pub async fn join_intent_reset_leaving( transport_opts: T::Options, - addr: ::ResolvedAddress, + addr: T::ResolvedAddress, ) where T: Transport, { @@ -272,7 +272,7 @@ where /// Unit tests for the join pending intent logic pub async fn join_pending_intent( 
transport_opts: T::Options, - addr: ::ResolvedAddress, + addr: T::ResolvedAddress, ) where T: Transport, { @@ -312,7 +312,7 @@ pub async fn join_pending_intent( /// Unit tests for the join pending intent logic pub async fn join_pending_intents( transport_opts: T::Options, - addr: ::ResolvedAddress, + addr: T::ResolvedAddress, ) where T: Transport, { diff --git a/serf-core/src/serf/base/tests/serf/leave.rs b/serf-core/src/serf/base/tests/serf/leave.rs index aa705d9..62d4208 100644 --- a/serf-core/src/serf/base/tests/serf/leave.rs +++ b/serf-core/src/serf/base/tests/serf/leave.rs @@ -34,7 +34,7 @@ where /// Unit tests for the leave intent old message pub async fn leave_intent_old_message( transport_opts: T::Options, - addr: ::ResolvedAddress, + addr: T::ResolvedAddress, ) where T: Transport, { @@ -86,7 +86,7 @@ pub async fn leave_intent_old_message( /// Unit tests for the leave intent newer pub async fn leave_intent_newer( transport_opts: T::Options, - addr: ::ResolvedAddress, + addr: T::ResolvedAddress, ) where T: Transport, { @@ -417,7 +417,7 @@ pub async fn serf_leave_rejoin_different_role( pub async fn serf_leave_snapshot_recovery( transport_opts1: T::Options, transport_opts2: T::Options, - get_transport: impl FnOnce(T::Id, ::ResolvedAddress) -> F + Copy, + get_transport: impl FnOnce(T::Id, T::ResolvedAddress) -> F + Copy, ) where T: Transport, F: core::future::Future, diff --git a/serf-core/src/serf/base/tests/serf/reap.rs b/serf-core/src/serf/base/tests/serf/reap.rs index 3251f16..669a867 100644 --- a/serf-core/src/serf/base/tests/serf/reap.rs +++ b/serf-core/src/serf/base/tests/serf/reap.rs @@ -40,7 +40,7 @@ where /// Unit test for reap handler pub async fn serf_reap_handler( opts: T::Options, - addr: ::ResolvedAddress, + addr: T::ResolvedAddress, ) where T: Transport, { @@ -131,7 +131,7 @@ pub async fn serf_reap_handler( } /// Unit test for reap -pub async fn serf_reap(opts: T::Options, addr: ::ResolvedAddress) +pub async fn serf_reap(opts: T::Options, addr: 
T::ResolvedAddress) where T: Transport, { diff --git a/serf-core/src/serf/base/tests/serf/reconnect.rs b/serf-core/src/serf/base/tests/serf/reconnect.rs index a204ba0..5e3dcf4 100644 --- a/serf-core/src/serf/base/tests/serf/reconnect.rs +++ b/serf-core/src/serf/base/tests/serf/reconnect.rs @@ -10,7 +10,7 @@ use super::*; pub async fn serf_reconnect( transport_opts1: T::Options, transport_opts2: T::Options, - get_transport: impl FnOnce(T::Id, ::ResolvedAddress) -> F + Copy, + get_transport: impl FnOnce(T::Id, T::ResolvedAddress) -> F + Copy, ) where T: Transport, F: core::future::Future, @@ -74,7 +74,7 @@ pub async fn serf_reconnect( pub async fn serf_reconnect_same_ip( transport_opts1: T::Options, transport2_id: T::Id, - get_transport: impl FnOnce(T::Id, ::ResolvedAddress) -> F + Copy, + get_transport: impl FnOnce(T::Id, T::ResolvedAddress) -> F + Copy, ) where T: Transport>, T::Options: Clone, diff --git a/serf-core/src/serf/base/tests/serf/snapshot.rs b/serf-core/src/serf/base/tests/serf/snapshot.rs index 4b38230..9547d9d 100644 --- a/serf-core/src/serf/base/tests/serf/snapshot.rs +++ b/serf-core/src/serf/base/tests/serf/snapshot.rs @@ -5,7 +5,7 @@ use super::*; /// Unit test for the snapshoter. pub async fn snapshoter( transport_opts: T::Options, - addr: ::ResolvedAddress, + addr: T::ResolvedAddress, ) where T: Transport, { @@ -214,7 +214,7 @@ pub async fn snapshoter( /// Unit test for the snapshoter force compact. 
pub async fn snapshoter_force_compact( transport_opts: T::Options, - addr: ::ResolvedAddress, + addr: T::ResolvedAddress, ) where T: Transport, { @@ -283,7 +283,7 @@ pub async fn snapshoter_force_compact( /// Unit test for the snapshoter leave pub async fn snapshoter_leave( transport_opts: T::Options, - addr: ::ResolvedAddress, + addr: T::ResolvedAddress, ) where T: Transport, { @@ -394,7 +394,7 @@ pub async fn snapshoter_leave( /// Unit test for the snapshoter leave rejoin pub async fn snapshoter_leave_rejoin( transport_opts: T::Options, - addr: ::ResolvedAddress, + addr: T::ResolvedAddress, ) where T: Transport, { @@ -502,7 +502,7 @@ pub async fn snapshoter_leave_rejoin( pub async fn serf_snapshot_recovery( transport_opts1: T::Options, transport_opts2: T::Options, - get_transport: impl FnOnce(T::Id, ::ResolvedAddress) -> F + Copy, + get_transport: impl FnOnce(T::Id, T::ResolvedAddress) -> F + Copy, ) where T: Transport, T::Options: Clone, diff --git a/serf-core/src/serf/delegate.rs b/serf-core/src/serf/delegate.rs index bfefe0c..bb3d8b4 100644 --- a/serf-core/src/serf/delegate.rs +++ b/serf-core/src/serf/delegate.rs @@ -2,18 +2,22 @@ use crate::{ Serf, broadcast::SerfBroadcast, delegate::Delegate, - error::SerfError, + error::{Error, SerfDelegateError, SerfError}, event::QueryMessageExt, types::{ DelegateVersion, JoinMessage, LamportTime, LeaveMessage, Member, MemberStatus, - MemberlistDelegateVersion, MemberlistProtocolVersion, MessageType, ProtocolVersion, - PushPullMessageBorrow, SerfMessage, UserEventMessage, + MemberlistDelegateVersion, MemberlistProtocolVersion, MessageRef, MessageType, ProtocolVersion, + PushPullMessageBorrow, UserEventMessage, }, }; -use std::sync::{Arc, OnceLock, atomic::Ordering}; +use std::{ + borrow::Cow, + sync::{Arc, OnceLock, atomic::Ordering}, +}; use arc_swap::ArcSwap; +use either::Either; use indexmap::IndexSet; use memberlist_core::{ CheapClone, META_MAX_SIZE, @@ -22,11 +26,11 @@ use memberlist_core::{ AliveDelegate, 
ConflictDelegate, Delegate as MemberlistDelegate, EventDelegate, MergeDelegate as MemberlistMergeDelegate, NodeDelegate, PingDelegate, }, - proto::{Meta, NodeState, SmallVec, State, TinyVec}, + proto::{Data, Meta, NodeState, SmallVec, State, TinyVec}, tracing, transport::{AddressResolver, Transport}, }; -use serf_proto::Tags; +use serf_proto::{PushPullMessage, Tags}; // PingVersion is an internal version for the ping message, above the normal // versioning we get from the protocol version. This enables small updates @@ -41,7 +45,7 @@ pub(crate) trait MessageDropper: Send + Sync + 'static { /// The memberlist delegate for Serf. pub struct SerfDelegate where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { serf: OnceLock>, @@ -58,7 +62,7 @@ where impl SerfDelegate where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { pub(crate) fn new(d: Option, tags: Arc>) -> Self { @@ -110,14 +114,14 @@ where impl NodeDelegate for SerfDelegate where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { async fn node_meta(&self, limit: usize) -> Meta { let tags = self.tags.load(); match tags.is_empty() { false => { - let encoded_len = tags_encoded_len(&tags); + let encoded_len = tags.encoded_len(); let limit = limit.min(Meta::MAX_SIZE); if encoded_len > limit { panic!( @@ -127,14 +131,8 @@ where } let mut role_bytes = vec![0; encoded_len]; - match encode_tags(&tags, &mut role_bytes) { + match tags.encode(&mut role_bytes) { Ok(len) => { - debug_assert_eq!( - len, encoded_len, - "expected encoded len {} mismatch the actual encoded len {}", - encoded_len, len - ); - if len > limit { panic!( "node tags {:?} exceeds length limit of {} bytes", @@ -154,9 +152,16 @@ where } } - async fn notify_message(&self, mut msg: Bytes) { + async fn notify_message(&self, buf: Cow<'_, [u8]>) { + fn to_owned(buf: Cow<'_, [u8]>) -> Bytes { + match buf { + Cow::Borrowed(buf) => Bytes::copy_from_slice(buf), + Cow::Owned(buf) => Bytes::from(buf), + } + } + // If we 
didn't actually receive any data, then ignore it. - if msg.is_empty() { + if buf.is_empty() { return; } @@ -172,14 +177,15 @@ where .metric_labels .iter() ) - .record(msg.len() as f64); + .record(buf.len() as f64); } let this = self.this(); - let mut rebroadcast = None; + let mut rebroadcast = false; let mut rebroadcast_queue = &this.inner.broadcasts; - match MessageType::try_from(msg[0]) { - Ok(ty) => { + let mut relay = None; + match serf_proto::decode_message::(buf.as_ref()) { + Ok(msg) => { #[cfg(any(test, feature = "test"))] { if let Some(ref dropper) = this.inner.memberlist.delegate().unwrap().message_dropper { @@ -189,130 +195,132 @@ where } } - match ty { - MessageType::Leave => match decode_message(ty, &msg[1..]) { - Ok((_, l)) => { - if let SerfMessage::Leave(l) = &l { - tracing::debug!("serf: leave message: {}", l.id()); - rebroadcast = this.handle_node_leave_intent(l).await.then(|| msg.clone()); - } else { - tracing::warn!("serf: receive unexpected message: {}", l.ty().as_str()); + match msg { + MessageRef::Leave(l) => { + tracing::debug!("serf: leave message: {:?}", l.id()); + // TODO(al8n): do not read to owned here + match as Data>::from_ref(l) { + Err(e) => { + tracing::error!(err=%e, "serf: failed to decode leave message"); } - } - Err(e) => { - tracing::warn!(err=%e, "serf: failed to decode message"); - } - }, - MessageType::Join => match decode_message(ty, &msg[1..]) { - Ok((_, j)) => { - if let SerfMessage::Join(j) = &j { - tracing::debug!("serf: join message: {}", j.id()); - rebroadcast = this.handle_node_join_intent(j).await.then(|| msg.clone()); - } else { - tracing::warn!("serf: receive unexpected message: {}", j.ty().as_str()); + Ok(l) => { + rebroadcast = this.handle_node_leave_intent(&l).await; } - } - Err(e) => { - tracing::warn!(err=%e, "serf: failed to decode message"); - } - }, - MessageType::UserEvent => match decode_message(ty, &msg[1..]) { - Ok((_, ue)) => { - if let SerfMessage::UserEvent(ue) = ue { - tracing::debug!("serf: user 
event message: {}", ue.name); - rebroadcast = this.handle_user_event(ue).await.then(|| msg.clone()); - rebroadcast_queue = &this.inner.event_broadcasts; - } else { - tracing::warn!("serf: receive unexpected message: {}", ue.ty().as_str()); + }; + } + MessageRef::Join(j) => { + tracing::debug!("serf: join message: {:?}", j.id()); + // TODO(al8n): do not read to owned here + + match as Data>::from_ref(j) { + Err(e) => { + tracing::error!(err=%e, "serf: failed to decode join message"); } - } - Err(e) => { - tracing::warn!(err=%e, "serf: failed to decode message"); - } - }, - MessageType::Query => match decode_message(ty, &msg[1..]) { - Ok((_, q)) => { - if let SerfMessage::Query(q) = q { - tracing::debug!("serf: query message: {}", q.name); - match q.decode_internal_query::() { - Some(Err(e)) => { - tracing::warn!(err=%e, "serf: failed to decode message"); - } - Some(Ok(res)) => { - rebroadcast = this.handle_query(q, Some(res)).await.then(|| msg.clone()); - rebroadcast_queue = &this.inner.query_broadcasts; - } - None => { - rebroadcast = this.handle_query(q, None).await.then(|| msg.clone()); - rebroadcast_queue = &this.inner.query_broadcasts; - } - }; - } else { - tracing::warn!("serf: receive unexpected message: {}", q.ty().as_str()); + Ok(j) => { + rebroadcast = this.handle_node_join_intent(&j).await; } - } - Err(e) => { - tracing::warn!(err=%e, "serf: failed to decode message"); - } - }, - MessageType::QueryResponse => match decode_message(ty, &msg[1..]) { - Ok((_, qr)) => { - if let SerfMessage::QueryResponse(qr) = qr { - tracing::debug!("serf: query response message: {}", qr.from); - this.handle_query_response(qr).await; - } else { - tracing::warn!("serf: receive unexpected message: {}", qr.ty().as_str()); + }; + } + MessageRef::UserEvent(ue) => { + tracing::debug!("serf: user event message: {}", ue.name()); + rebroadcast = this.handle_user_event(either::Either::Left(ue)).await; + rebroadcast_queue = &this.inner.event_broadcasts; + } + MessageRef::Query(q) => { + 
tracing::debug!("serf: query message: {}", q.name()); + match q.decode_internal_query() { + Some(Err(e)) => { + tracing::warn!(err=%e, "serf: failed to decode message"); } + Some(Ok(res)) => match this.handle_query(either::Either::Left(q), Some(res)).await { + Ok(val) => { + rebroadcast = val; + rebroadcast_queue = &this.inner.query_broadcasts; + } + Err(e) => { + tracing::warn!(err=%e, "serf: failed to decode query message"); + } + }, + None => match this.handle_query(either::Either::Left(q), None).await { + Ok(val) => { + rebroadcast = val; + rebroadcast_queue = &this.inner.query_broadcasts; + } + Err(e) => { + tracing::warn!(err=%e, "serf: failed to decode query message"); + } + }, } - Err(e) => { - tracing::warn!(err=%e, "serf: failed to decode message"); + } + MessageRef::QueryResponse(qr) => { + tracing::debug!("serf: query response message: {:?}", qr.from()); + if let Err(e) = this.handle_query_response(qr).await { + tracing::warn!(err=%e, "serf: failed to decode query response message"); } - }, - MessageType::Relay => match decode_node(&msg[1..]) { - Ok((consumed, n)) => { - tracing::debug!("serf: relay message",); - tracing::debug!("serf: relaying response to node: {}", n); - // + 1 for the message type byte - msg.advance(consumed + 1); - if let Err(e) = this.inner.memberlist.send(n.address(), msg.clone()).await { - tracing::error!(err=%e, "serf: failed to forwarding message to {}", n); + } + MessageRef::Relay { + node, + payload, + payload_offset, + } => { + tracing::debug!("serf: relaying response to node: {:?}", node); + match Data::from_ref(*node.address()) { + Err(e) => { + tracing::error!(err=%e, "serf: failed to encode address"); } + Ok(addr) => match buf { + Cow::Borrowed(_) => { + relay = Some((addr, Either::Left(Bytes::copy_from_slice(payload)))); + } + Cow::Owned(_) => { + relay = Some((addr, Either::Right((payload_offset, payload.len())))); + } + }, } - Err(e) => { - tracing::warn!(err=%e, "serf: failed to decode relay destination"); - } - }, - 
ty => { - tracing::warn!("serf: receive unexpected message: {}", ty.as_str()); + } + msg => { + tracing::warn!("serf: receive unexpected message type: {}", msg.ty()); } } } Err(e) => { - tracing::warn!(err=%e, "serf: receive unknown message type"); + tracing::warn!(err=%e, "serf: failed to decode message"); } } - if let Some(msg) = rebroadcast { + if rebroadcast { rebroadcast_queue .queue_broadcast(SerfBroadcast { - msg, + msg: to_owned(buf), notify_tx: None, }) .await; + } else if let Some((addr, payload)) = relay { + let msg = match payload { + Either::Left(p) => p, + Either::Right((offset, len)) => { + let mut buf = to_owned(buf); + buf.advance(offset); + buf.split_to(len) + } + }; + + if let Err(e) = this.inner.memberlist.send(&addr, msg).await { + tracing::error!(err=%e, "serf: failed to forwarding message to {}", addr); + } } } async fn broadcast_messages( &self, - overhead: usize, limit: usize, encoded_len: F, - ) -> TinyVec + ) -> impl Iterator + Send where - F: Fn(Bytes) -> (usize, Bytes) + Send, + F: Fn(Bytes) -> (usize, Bytes) + Send + Sync + 'static, { let this = self.this(); - let mut msgs = this.inner.broadcasts.get_broadcasts(overhead, limit).await; + let mut msgs = this.inner.broadcasts.get_broadcasts(limit).await; // Determine the bytes used already let mut bytes_used = 0; @@ -333,7 +341,7 @@ where let query_msgs = this .inner .query_broadcasts - .get_broadcasts(overhead, limit - bytes_used) + .get_broadcasts(limit - bytes_used) .await; for msg in query_msgs.iter() { let (encoded_len, _) = encoded_len(msg.clone()); @@ -352,7 +360,7 @@ where let event_msgs = this .inner .event_broadcasts - .get_broadcasts(overhead, limit - bytes_used) + .get_broadcasts(limit - bytes_used) .await; for msg in event_msgs.iter() { let (encoded_len, _) = encoded_len(msg.clone()); @@ -368,7 +376,7 @@ where } msgs.extend(query_msgs); msgs.extend(event_msgs); - msgs + msgs.into_iter() } async fn local_state(&self, _join: bool) -> Bytes { @@ -406,16 +414,19 @@ where } } - 
async fn merge_remote_state(&self, buf: Bytes, is_join: bool) { + async fn merge_remote_state(&self, buf: &[u8], is_join: bool) { if buf.is_empty() { tracing::error!("serf: remote state is zero bytes"); return; } // Check the message type - let Ok(ty) = MessageType::try_from(buf[0]) else { - tracing::error!("serf: remote state has bad type prefix {}", buf[0]); - return; + let msg = match serf_proto::decode_message::(buf) { + Ok(msg) => msg, + Err(e) => { + tracing::error!(err=%e, "serf: fail to decode remote state"); + return; + } }; #[cfg(any(test, feature = "test"))] @@ -434,108 +445,100 @@ where } } - match ty { - MessageType::PushPull => { - match decode_message(ty, &buf[1..]) { + match msg { + MessageRef::PushPull(pp) => { + let ltime = pp.ltime(); + let event_ltime = pp.event_ltime(); + let query_ltime = pp.query_ltime(); + let this = self.this(); + // Witness the Lamport clocks first. + // We subtract 1 since no message with that clock has been sent yet + if ltime > LamportTime::ZERO { + this.inner.clock.witness(ltime - LamportTime::new(1)); + } + if event_ltime > LamportTime::ZERO { + this + .inner + .event_clock + .witness(event_ltime - LamportTime::new(1)); + } + if query_ltime > LamportTime::ZERO { + this + .inner + .query_clock + .witness(query_ltime - LamportTime::new(1)); + } + + let pp = match as Data>::from_ref(pp) { + Ok(pp) => pp, Err(e) => { - tracing::error!(err=%e, "serf: failed to decode remote state"); + tracing::error!(err=%e, "serf: failed to decode push pull message"); + return; } - Ok((_, msg)) => { - match msg { - SerfMessage::PushPull(pp) => { - let this = self.this(); - // Witness the Lamport clocks first. 
- // We subtract 1 since no message with that clock has been sent yet - if pp.ltime > LamportTime::ZERO { - this.inner.clock.witness(pp.ltime - LamportTime::new(1)); - } - if pp.event_ltime > LamportTime::ZERO { - this - .inner - .event_clock - .witness(pp.event_ltime - LamportTime::new(1)); - } - if pp.query_ltime > LamportTime::ZERO { - this - .inner - .query_clock - .witness(pp.query_ltime - LamportTime::new(1)); - } + }; + + // Process the left nodes first to avoid the LTimes from incrementing + // in the wrong order. Note that we don't have the actual Lamport time + // for the leave message, so we go one past the join time, since the + // leave must have been accepted after that to get onto the left members + // list. If we didn't do this then the message would not get processed. + for node in &pp.left_members { + if let Some(<ime) = pp.status_ltimes.get(node) { + this + .handle_node_leave_intent(&LeaveMessage { + ltime: ltime + LamportTime::new(1), + id: node.cheap_clone(), + prune: false, + }) + .await; + } else { + tracing::error!( + "serf: {} is in left members, but cannot find the lamport time for it in status", + node + ); + } + } - // Process the left nodes first to avoid the LTimes from incrementing - // in the wrong order. Note that we don't have the actual Lamport time - // for the leave message, so we go one past the join time, since the - // leave must have been accepted after that to get onto the left members - // list. If we didn't do this then the message would not get processed. 
- for node in &pp.left_members { - if let Some(<ime) = pp.status_ltimes.get(node) { - this - .handle_node_leave_intent(&LeaveMessage { - ltime: ltime + LamportTime::new(1), - id: node.cheap_clone(), - prune: false, - }) - .await; - } else { - tracing::error!( - "serf: {} is in left members, but cannot find the lamport time for it in status", - node - ); - } - } + // Update any other LTimes + for (node, ltime) in pp.status_ltimes { + // Skip the left nodes + if pp.left_members.contains(&node) { + continue; + } - // Update any other LTimes - for (node, ltime) in pp.status_ltimes { - // Skip the left nodes - if pp.left_members.contains(&node) { - continue; - } - - // Create an artificial join message - this - .handle_node_join_intent(&JoinMessage { ltime, id: node }) - .await; - } + // Create an artificial join message + this + .handle_node_join_intent(&JoinMessage { ltime, id: node }) + .await; + } - // If we are doing a join, and eventJoinIgnore is set - // then we set the eventMinTime to the EventLTime. This - // prevents any of the incoming events from being processed - let event_join_ignore = this.inner.event_join_ignore.load(Ordering::Acquire); - if is_join && event_join_ignore { - let mut ec = this.inner.event_core.write().await; - if pp.event_ltime > ec.min_time { - ec.min_time = pp.event_ltime; - } - } + // If we are doing a join, and eventJoinIgnore is set + // then we set the eventMinTime to the EventLTime. 
This + // prevents any of the incoming events from being processed + let event_join_ignore = this.inner.event_join_ignore.load(Ordering::Acquire); + if is_join && event_join_ignore { + let mut ec = this.inner.event_core.write().await; + if event_ltime > ec.min_time { + ec.min_time = event_ltime; + } + } - // Process all the events - for events in pp.events { - match events { - Some(events) => { - for e in events.events { - this - .handle_user_event(UserEventMessage { - ltime: events.ltime, - name: e.name, - payload: e.payload, - cc: false, - }) - .await; - } - } - None => continue, - } - } - } - msg => { - tracing::error!("serf: remote state has bad type {}", msg.ty().as_str()); - } - } + // Process all the events + for events in pp.events { + for e in events.events { + this + .handle_user_event(either::Either::Right(UserEventMessage { + ltime: events.ltime, + name: e.name, + payload: e.payload, + cc: false, + })) + .await; } } } - ty => { - tracing::error!("serf: remote state has bad type {}", ty.as_str()); + msg => { + tracing::error!("serf: remote state has bad message type {}", msg.ty()); } } } @@ -543,11 +546,11 @@ where impl EventDelegate for SerfDelegate where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { type Id = T::Id; - type Address = ::ResolvedAddress; + type Address = T::ResolvedAddress; async fn notify_join(&self, node: Arc>) { if let Some(serf) = self.serf.get() { @@ -566,11 +569,11 @@ where impl AliveDelegate for SerfDelegate where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { type Id = T::Id; - type Address = ::ResolvedAddress; + type Address = T::ResolvedAddress; type Error = SerfDelegateError; async fn notify_alive( @@ -578,11 +581,11 @@ where node: Arc>, ) -> Result<(), Self::Error> { if let Some(ref d) = self.delegate { - let member = node_to_member::(node)?; + let member = node_to_member::(&node)?; return d - .notify_merge(TinyVec::from(member)) + .notify_merge(Arc::from_iter([member])) .await - 
.map_err(SerfDelegateError::merge); + .map_err(SerfDelegateError::Merge); } Ok(()) @@ -591,26 +594,26 @@ where impl MemberlistMergeDelegate for SerfDelegate where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { type Id = T::Id; - type Address = ::ResolvedAddress; + type Address = T::ResolvedAddress; type Error = SerfDelegateError; async fn notify_merge( &self, - peers: SmallVec>>, + peers: Arc<[NodeState]>, ) -> Result<(), Self::Error> { if let Some(ref d) = self.delegate { let peers = peers - .into_iter() + .iter() .map(node_to_member::) - .collect::, _>>()?; + .collect::, _>>()?; return d .notify_merge(peers) .await - .map_err(SerfDelegateError::merge); + .map_err(SerfDelegateError::Merge); } Ok(()) } @@ -618,12 +621,12 @@ where impl ConflictDelegate for SerfDelegate where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { type Id = T::Id; - type Address = ::ResolvedAddress; + type Address = T::ResolvedAddress; async fn notify_conflict( &self, @@ -636,12 +639,12 @@ where impl PingDelegate for SerfDelegate where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { type Id = T::Id; - type Address = ::ResolvedAddress; + type Address = T::ResolvedAddress; async fn ack_payload(&self) -> Bytes { #[cfg(any(feature = "test", test))] @@ -664,9 +667,9 @@ where coord.portion.resize(len * 2, 0.0); // The rest of the message is the serialized coordinate. 
- let len = coordinate_encoded_len(&coord); + let len = coord.encoded_len(); buf.resize(len + 1, 0); - if let Err(e) = encode_coordinate(&coord, &mut buf[1..]) { + if let Err(e) = coord.encode(&mut buf[1..]) { panic!("failed to encode coordinate: {}", e); } return buf.freeze(); @@ -674,12 +677,12 @@ where if let Some(c) = self.this().inner.coord_core.as_ref() { let coord = c.client.get_coordinate(); - let encoded_len = coordinate_encoded_len(&coord) + 1; + let encoded_len = coord.encoded_len() + 1; let mut buf = BytesMut::with_capacity(encoded_len); buf.put_u8(PING_VERSION); buf.resize(encoded_len, 0); - if let Err(e) = encode_coordinate(&coord, &mut buf[1..]) { + if let Err(e) = coord.encode(&mut buf[1..]) { tracing::error!(err=%e, "serf: failed to encode coordinate"); } buf.into() @@ -708,7 +711,7 @@ where } // Process the remainder of the message as a coordinate. - let coord = match decode_coordinate(&payload[1..]) { + let coord = match ::decode(&payload[1..]) { Ok((readed, c)) => { tracing::trace!(read=%readed, coordinate=?c, "serf: decode coordinate successfully"); c @@ -764,28 +767,23 @@ where } } } - - #[inline] - fn disable_promised_pings(&self, _id: &Self::Id) -> bool { - false - } } impl MemberlistDelegate for SerfDelegate where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { type Id = T::Id; - type Address = ::ResolvedAddress; + type Address = T::ResolvedAddress; } fn node_to_member( - node: Arc::ResolvedAddress>>, -) -> Result::ResolvedAddress>, SerfDelegateError> + node: &NodeState, +) -> Result, SerfDelegateError> where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { let status = if node.state() == State::Left { @@ -796,25 +794,25 @@ where let meta = node.meta(); if meta.len() > META_MAX_SIZE { - return Err(SerfDelegateError::serf(SerfError::TagsTooLarge(meta.len()))); + return Err(SerfDelegateError::Serf(SerfError::TagsTooLarge(meta.len()))); } Ok(Member { node: node.node(), tags: if !node.meta().is_empty() { - 
decode_tags(node.meta()) + ::decode(node.meta()) .map(|(read, tags)| { tracing::trace!(read=%read, tags=?tags, "serf: decode tags successfully"); Arc::new(tags) }) - .map_err(SerfDelegateError::transform)? + .map_err(|e| SerfDelegateError::Serf(SerfError::from(e)))? } else { Default::default() }, status, protocol_version: ProtocolVersion::V1, delegate_version: DelegateVersion::V1, - memberlist_delegate_version: MemberlistDelegateVersion::V1, - memberlist_protocol_version: MemberlistProtocolVersion::V1, + memberlist_delegate_version: node.delegate_version(), + memberlist_protocol_version: node.protocol_version(), }) } diff --git a/serf-core/src/serf/internal_query.rs b/serf-core/src/serf/internal_query.rs index b00c645..2861b97 100644 --- a/serf-core/src/serf/internal_query.rs +++ b/serf-core/src/serf/internal_query.rs @@ -2,15 +2,15 @@ use async_channel::{Receiver, Sender, bounded}; use futures::FutureExt; use memberlist_core::{ agnostic_lite::{AsyncSpawner, RuntimeLite}, - bytes::{BufMut, Bytes, BytesMut}, + bytes::Bytes, tracing, - transport::{AddressResolver, Transport}, + transport::Transport, }; use crate::{ delegate::Delegate, event::{CrateEvent, InternalQueryEvent, QueryEvent}, - types::MessageType, + types::MessageRef, }; #[cfg(feature = "encryption")] @@ -29,7 +29,7 @@ const MIN_ENCODED_KEY_LENGTH: usize = 25; pub(crate) struct SerfQueries where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { in_rx: Receiver>, @@ -39,7 +39,7 @@ where impl SerfQueries where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { #[allow(clippy::new_ret_no_self)] @@ -179,9 +179,9 @@ where async fn handle_install_key(ev: impl AsRef> + Send) { let q = ev.as_ref(); let mut response = KeyResponseMessage::default(); - let req = match decode_message(MessageType::KeyRequest, &q.payload[1..]) { - Ok((_, msg)) => match msg { - SerfMessage::KeyRequest(req) => req, + let req = match serf_proto::decode_message::(&q.payload) { + Ok(msg) => match msg { + 
MessageRef::KeyRequest(req) => req, msg => { tracing::error!( err = "unexpected message type", @@ -242,15 +242,11 @@ where let q = ev.as_ref(); let mut response = KeyResponseMessage::default(); - let req = match decode_message(MessageType::KeyRequest, &q.payload[1..]) { - Ok((_, msg)) => match msg { - SerfMessage::KeyRequest(req) => req, + let req = match serf_proto::decode_message::(&q.payload) { + Ok(msg) => match msg { + MessageRef::KeyRequest(req) => req, msg => { - tracing::error!( - err = "unexpected message type", - "serf: {}", - msg.ty().as_str() - ); + tracing::error!(err = "unexpected message type", "serf: {}", msg.ty()); Self::send_key_response(q, &mut response).await; return; } @@ -311,15 +307,11 @@ where let q = ev.as_ref(); let mut response = KeyResponseMessage::default(); - let req = match decode_message(MessageType::KeyRequest, &q.payload[1..]) { - Ok((_, msg)) => match msg { - SerfMessage::KeyRequest(req) => req, + let req = match serf_proto::decode_message::(&q.payload) { + Ok(msg) => match msg { + MessageRef::KeyRequest(req) => req, msg => { - tracing::error!( - err = "unexpected message type", - "serf: {}", - msg.ty().as_str() - ); + tracing::error!(err = "unexpected message type", "serf: {}", msg.ty()); Self::send_key_response(q, &mut response).await; return; } @@ -422,7 +414,7 @@ where ) -> Result< ( Bytes, - serf_proto::QueryResponseMessage::ResolvedAddress>, + serf_proto::QueryResponseMessage, ), Error, > { diff --git a/serf-core/src/serf/query.rs b/serf-core/src/serf/query.rs index 19cad1f..17340f9 100644 --- a/serf-core/src/serf/query.rs +++ b/serf-core/src/serf/query.rs @@ -6,21 +6,21 @@ use std::{ use async_channel::{Receiver, Sender}; use async_lock::RwLock; +use either::Either; use futures::{FutureExt, StreamExt, stream::FuturesUnordered}; use memberlist_core::{ CheapClone, - bytes::{BufMut, Bytes, BytesMut}, - proto::{OneOrMore, SmallVec, TinyVec}, + bytes::Bytes, + proto::{Data, RepeatedDecoder, SmallVec, TinyVec}, tracing, - 
transport::{AddressResolver, Id, Node, Transport}, + transport::{Node, Transport}, }; +use serf_proto::FilterRef; use crate::{ delegate::Delegate, error::Error, - types::{ - Filter, LamportTime, Member, MemberStatus, MessageType, QueryMessage, QueryResponseMessage, - }, + types::{Filter, LamportTime, Member, MemberStatus, QueryMessage, QueryResponseMessage}, }; use super::Serf; @@ -40,7 +40,7 @@ pub struct QueryParam { getter(const, attrs(doc = "Returns the filters of the query")), setter(attrs(doc = "Sets the filters of the query")) )] - filters: OneOrMore>, + filters: TinyVec>, /// If true, we are requesting an delivery acknowledgement from /// every node that meets the filter requirement. This means nodes @@ -245,7 +245,7 @@ impl QueryResponse { ) where I: Eq + std::hash::Hash + CheapClone + core::fmt::Debug, A: Eq + std::hash::Hash + CheapClone + core::fmt::Debug, - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { // Check if the query is closed @@ -308,7 +308,7 @@ impl QueryResponse { where I: Eq + std::hash::Hash + CheapClone + core::fmt::Debug, A: Eq + std::hash::Hash + CheapClone + core::fmt::Debug, - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { let mut c = self.inner.core.write().await; @@ -342,7 +342,7 @@ impl QueryResponse { where I: Eq + std::hash::Hash + CheapClone, A: Eq + std::hash::Hash + CheapClone, - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { let mut c = self.inner.core.write().await; @@ -410,7 +410,7 @@ fn random_members(k: usize, mut members: SmallVec>) -> SmallV impl Serf where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { /// Returns the default timeout value for a query @@ -429,78 +429,103 @@ where /// Used to return the default query parameters pub async fn default_query_param(&self) -> QueryParam { QueryParam { - filters: OneOrMore::new(), + filters: TinyVec::new(), request_ack: false, relay_factor: 0, timeout: self.default_query_timeout().await, } } - pub(crate) fn 
should_process_query(&self, filters: &[Bytes]) -> bool { - for filter in filters.iter() { - if filter.is_empty() { - tracing::warn!("serf: empty filter"); - return false; - } - - // Decode the filter - let filter = match decode_filter(filter) { - Ok((read, filter)) => { - tracing::trace!(read=%read, filter=?filter, "serf: decoded filter successully"); - filter - } - Err(err) => { - tracing::warn!( - err = %err, - "serf: failed to decode filter" - ); - return false; - } - }; - - match filter { - Filter::Id(nodes) => { - // Check if we are being targeted - let found = nodes - .iter() - .any(|n: &T::Id| n.eq(self.inner.memberlist.local_id())); - if !found { - return false; + pub(crate) fn should_process_query( + &self, + filters: Either, &[Filter]>, + ) -> Result { + match filters { + Either::Left(filters) => { + for filter in filters.iter::>() { + let filter = filter?; + match filter { + FilterRef::Id(ids) => { + // Check if we are being targeted + let mut found = false; + for id in ids.iter::() { + let id = id?; + if ::from_ref(id)?.eq(self.inner.memberlist.local_id()) { + found = true; + break; + } + } + if !found { + return Ok(false); + } + } + FilterRef::Tag(tag) => { + // Check if we match this regex + let tags = self.inner.opts.tags.load(); + if !tags.is_empty() { + if let Some(expr) = tags.get(tag.tag()) { + if let Some(re) = tag.expr() { + if !regex::Regex::new(re) + .map_err(|_| memberlist_core::proto::DecodeError::custom("invalid regex"))? 
+ .is_match(expr) + { + return Ok(false); + } + } + } else { + return Ok(false); + } + } else { + return Ok(false); + } + } } } - Filter::Tag(tag) => { - // Check if we match this regex - let tags = self.inner.opts.tags.load(); - if !tags.is_empty() { - if let Some(expr) = tags.get(&tag) { - match regex::Regex::new(&fexpr) { - Ok(re) => { - if !re.is_match(expr) { - return false; + + Ok(true) + } + Either::Right(filters) => { + for filter in filters.iter() { + match &filter { + Filter::Id(nodes) => { + // Check if we are being targeted + let found = nodes + .iter() + .any(|n: &T::Id| n.eq(self.inner.memberlist.local_id())); + if !found { + return Ok(false); + } + } + Filter::Tag(tag) => { + // Check if we match this regex + let tags = self.inner.opts.tags.load(); + if !tags.is_empty() { + if let Some(expr) = tags.get(tag.tag()) { + if let Some(re) = tag.expr() { + if !re.is_match(expr) { + return Ok(false); + } } + } else { + return Ok(false); } - Err(err) => { - tracing::warn!(err=%err, "serf: failed to compile filter regex ({})", fexpr); - return false; - } + } else { + return Ok(false); } - } else { - return false; } - } else { - return false; + _ => {} } } + Ok(true) } } - true } pub(crate) async fn relay_response( &self, relay_factor: u8, - node: Node::ResolvedAddress>, - resp: QueryResponseMessage::ResolvedAddress>, + node: Node, + resp: QueryResponseMessage, ) -> Result<(), Error> { if relay_factor == 0 { return Ok(()); @@ -532,14 +557,14 @@ where } // Prep the relay message, which is a wrapped version of the original. 
- let encoded_len = serf_proto::Encodable::encoded_len_with_relay(&resp); + let encoded_len = serf_proto::Encodable::encoded_len_with_relay(&resp, &node); if encoded_len > self.inner.opts.query_response_size_limit { return Err(Error::relayed_response_too_large( self.inner.opts.query_response_size_limit, )); } - let raw = serf_proto::Encodable::encode_relay_to_bytes(&resp)?; + let raw = serf_proto::Encodable::encode_relay_to_bytes(&resp, &node)?; // Relay to a random set of peers. let relay_members = random_members(relay_factor as usize, members); diff --git a/serf-core/src/snapshot.rs b/serf-core/src/snapshot.rs index 8010630..5ad9e71 100644 --- a/serf-core/src/snapshot.rs +++ b/serf-core/src/snapshot.rs @@ -20,7 +20,7 @@ use memberlist_core::{ bytes::{BufMut, BytesMut}, proto::{Data, TinyVec}, tracing, - transport::{AddressResolver, Id, MaybeResolvedAddress, Node, Transport}, + transport::{Id, MaybeResolvedAddress, Node, Transport}, }; use rand::seq::SliceRandom; use serf_proto::UserEventMessage; @@ -372,10 +372,10 @@ impl SnapshotHandle { /// them to disk, and providing a recovery mechanism at start time. pub(crate) struct Snapshot where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { - alive_nodes: HashSet::ResolvedAddress>>, + alive_nodes: HashSet>, clock: LamportClock, fh: Option>, last_flush: Epoch, @@ -431,12 +431,12 @@ macro_rules! 
tee_stream_flush_event { impl Snapshot where - D: Delegate::ResolvedAddress>, + D: Delegate, T: Transport, { #[allow(clippy::type_complexity)] pub(crate) fn from_replay_result( - replay_result: ReplayResult::ResolvedAddress>, + replay_result: ReplayResult, min_compact_size: u64, rejoin_after_leave: bool, clock: LamportClock, @@ -677,10 +677,7 @@ where } /// Used to handle a single member event - fn process_member_event( - &mut self, - e: &MemberEvent::ResolvedAddress>, - ) { + fn process_member_event(&mut self, e: &MemberEvent) { match e.ty { MemberEventType::Join => { for m in e.members() { @@ -713,10 +710,7 @@ where } } - fn try_append( - &mut self, - l: SnapshotRecord<'_, T::Id, ::ResolvedAddress>, - ) { + fn try_append(&mut self, l: SnapshotRecord<'_, T::Id, T::ResolvedAddress>) { if let Err(e) = self.append_line(l) { tracing::error!(err = %e, "serf: failed to update snapshot"); if self.last_attempted_compaction.elapsed() > SNAPSHOT_ERROR_RECOVERY_INTERVAL { @@ -733,7 +727,7 @@ where fn append_line( &mut self, - l: SnapshotRecord<'_, T::Id, ::ResolvedAddress>, + l: SnapshotRecord<'_, T::Id, T::ResolvedAddress>, ) -> Result<(), SnapshotError> { #[cfg(feature = "metrics")] let start = crate::types::Epoch::now(); diff --git a/serf-proto/src/filter.rs b/serf-proto/src/filter.rs index 1da01b3..8b88a52 100644 --- a/serf-proto/src/filter.rs +++ b/serf-proto/src/filter.rs @@ -1,12 +1,9 @@ use memberlist_proto::{ - Data, DataRef, DecodeError, EncodeError, TinyVec, WireType, - utils::{merge, split}, + Data, DataRef, DecodeError, EncodeError, RepeatedDecoder, TinyVec, WireType, + utils::{merge, skip, split}, }; -pub use id_filter::*; pub use tag_filter::*; - -mod id_filter; mod tag_filter; /// The type of filter @@ -85,31 +82,15 @@ const FILTER_ID_TAG: u8 = 1; const FILTER_TAG_TAG: u8 = 2; /// The reference type to [`Filter`] -pub enum FilterRef<'a, I> { +#[derive(Clone, Copy, Debug)] +pub enum FilterRef<'a> { /// Filter by node ids - Id(IdDecoder<'a, I>), + 
Id(RepeatedDecoder<'a>), /// Filter by tag Tag(TagFilterRef<'a>), } -impl Clone for FilterRef<'_, I> { - fn clone(&self) -> Self { - *self - } -} - -impl Copy for FilterRef<'_, I> {} - -impl core::fmt::Debug for FilterRef<'_, I> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Id(id) => f.debug_tuple("FilterRef::Id").field(id).finish(), - Self::Tag(t) => f.debug_tuple("FilterRef::Tag").field(t).finish(), - } - } -} - -impl<'a, I> DataRef<'a, Filter> for FilterRef<'a, I> +impl<'a, I> DataRef<'a, Filter> for FilterRef<'a> where I: Data, { @@ -123,26 +104,70 @@ where } let mut offset = 0; + let mut ids_offsets = None; + let mut num_ids = 0; + let mut f = None; - match buf[0] { - val if val == Filter::::id_byte() => { - offset += 1; - Ok((offset, Self::Id(IdDecoder::new(&buf[offset..])))) - } - val if val == Filter::::tag_byte() => { - offset += 1; - let (read, tag) = - >::decode_length_delimited(&buf[offset..])?; - offset += read; - Ok((offset, Self::Tag(tag))) - } - b => { - let (wire_type, tag) = split(b); - WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + while offset < buf_len { + match buf[offset] { + val if val == Filter::::id_byte() => { + let readed = skip(WireType::LengthDelimited, &buf[offset..])?; + if let Some((ref mut fnso, ref mut lnso)) = ids_offsets { + if *fnso > offset { + *fnso = offset - 1; + } + + if *lnso < offset + readed { + *lnso = offset + readed; + } + } else { + ids_offsets = Some((offset - 1, offset + readed)); + } + num_ids += 1; + offset += readed; + } + val if val == Filter::::tag_byte() => { + if let Some(Self::Tag(_)) = f { + return Err(DecodeError::duplicate_field( + "Filter", + "tag", + FILTER_TAG_TAG, + )); + } + + if ids_offsets.is_some() { + return Err(DecodeError::duplicate_field("Filter", "id", FILTER_ID_TAG)); + } - Err(DecodeError::unknown_tag("Filter", tag)) + offset += 1; + let (read, tag) = + >::decode_length_delimited(&buf[offset..])?; + offset += 
read; + f = Some(FilterRef::Tag(tag)); + } + b => { + let (wire_type, _) = split(b); + let wt = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += 1; + offset += skip(wt, &buf[offset..])?; + } } } + + Ok(( + offset, + if let Some(tag) = f { + tag + } else if let Some((start, end)) = ids_offsets { + Self::Id( + RepeatedDecoder::new(FILTER_ID_TAG, WireType::LengthDelimited, buf) + .with_nums(num_ids) + .with_offsets(start, end), + ) + } else { + return Err(DecodeError::missing_field("Filter", "value")); + }, + )) } } @@ -163,7 +188,7 @@ impl Data for Filter where I: Data, { - type Ref<'a> = FilterRef<'a, I>; + type Ref<'a> = FilterRef<'a>; fn from_ref(val: Self::Ref<'_>) -> Result where @@ -171,6 +196,7 @@ where { match val { FilterRef::Id(decoder) => decoder + .iter::() .map(|res| res.and_then(I::from_ref)) .collect::>() .map(Self::Id), @@ -183,7 +209,7 @@ where + match self { Filter::Id(ids) => ids .iter() - .map(|id| id.encoded_len_with_length_delimited()) + .map(|id| 1 + id.encoded_len_with_length_delimited()) .sum::(), Filter::Tag(tag) => 1 + tag.encoded_len_with_length_delimited(), } @@ -202,12 +228,19 @@ where match self { Filter::Id(ids) => { - buf[offset] = Self::id_byte(); - offset += 1; - ids .iter() .try_fold(&mut offset, |offset, id| { + if *offset >= buf_len { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + + buf[*offset] = Self::id_byte(); + *offset += 1; + *offset += id.encode_length_delimited(&mut buf[*offset..])?; Ok(offset) diff --git a/serf-proto/src/filter/id_filter.rs b/serf-proto/src/filter/id_filter.rs deleted file mode 100644 index 0de00a2..0000000 --- a/serf-proto/src/filter/id_filter.rs +++ /dev/null @@ -1,63 +0,0 @@ -use memberlist_proto::{Data, DataRef, DecodeError}; - -/// The decoder for ids -pub struct IdDecoder<'a, I> { - src: &'a [u8], - len: usize, - offset: usize, - has_err: bool, - _phantom: std::marker::PhantomData, -} - -impl Clone for IdDecoder<'_, I> 
{ - fn clone(&self) -> Self { - *self - } -} - -impl Copy for IdDecoder<'_, I> {} - -impl core::fmt::Debug for IdDecoder<'_, I> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("IdDecoder") - .field("src", &self.src) - .field("offset", &self.offset) - .finish() - } -} - -impl<'a, I> IdDecoder<'a, I> { - pub(super) const fn new(src: &'a [u8]) -> Self { - Self { - src, - offset: 0, - len: src.len(), - has_err: false, - _phantom: std::marker::PhantomData, - } - } -} - -impl<'a, I> Iterator for IdDecoder<'a, I> -where - I: Data, -{ - type Item = Result, DecodeError>; - - fn next(&mut self) -> Option { - if self.has_err || self.offset >= self.len { - return None; - } - - Some( - as DataRef<'_, I>>::decode_length_delimited(&self.src[self.offset..]) - .inspect_err(|_| { - self.has_err = true; - }) - .map(|(read, value)| { - self.offset += read; - value - }), - ) - } -} diff --git a/serf-proto/src/key.rs b/serf-proto/src/key.rs index 211f2a4..023faff 100644 --- a/serf-proto/src/key.rs +++ b/serf-proto/src/key.rs @@ -161,13 +161,25 @@ const KEY_RESPONSE_PRIMARY_KEY_BYTE: u8 = #[viewit::viewit(getters(style = "ref", vis_all = "pub"), setters(skip), vis_all = "")] #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub struct KeyResponseMessageRef<'a> { - #[viewit(getter(const, attrs(doc = "Returns true/false if there were errors or not")))] + #[viewit(getter( + const, + style = "move", + attrs(doc = "Returns true/false if there were errors or not") + ))] result: bool, - #[viewit(getter(const, attrs(doc = "Returns the error messages or other information")))] + #[viewit(getter( + const, + style = "move", + attrs(doc = "Returns the error messages or other information") + ))] message: &'a str, #[viewit(getter(const, attrs(doc = "Returns a list of installed keys")))] keys: RepeatedDecoder<'a>, - #[viewit(getter(const, attrs(doc = "Returns the primary key")))] + #[viewit(getter( + const, + attrs(doc = "Returns the primary key"), + 
result(converter(fn = "Option::as_ref"), type = "Option<&SecretKey>"), + ))] primary_key: Option, } diff --git a/serf-proto/src/message.rs b/serf-proto/src/message.rs index 0757c9e..bb2d87f 100644 --- a/serf-proto/src/message.rs +++ b/serf-proto/src/message.rs @@ -1,4 +1,13 @@ -use memberlist_proto::{Data, EncodeError, WireType, bytes::Bytes, utils::merge}; +use memberlist_proto::{ + Data, DataRef, DecodeError, EncodeError, Node, WireType, + bytes::Bytes, + utils::{merge, skip, split}, +}; + +use crate::{ + ConflictResponseMessageRef, PushPullMessage, PushPullMessageRef, QueryMessageRef, + QueryResponseMessageRef, UserEventMessageRef, +}; use super::{ ConflictResponseMessage, ConflictResponseMessageBorrow, JoinMessage, LeaveMessage, @@ -6,7 +15,7 @@ use super::{ }; #[cfg(feature = "encryption")] -use super::{KeyRequestMessage, KeyResponseMessage}; +use super::{KeyRequestMessage, KeyResponseMessage, KeyResponseMessageRef}; const LEAVE_MESSAGE_TAG: u8 = 1; const JOIN_MESSAGE_TAG: u8 = 2; @@ -38,33 +47,44 @@ const KEY_RESPONSE_MESSAGE_BYTE: u8 = merge(WireType::LengthDelimited, KEY_RESPO /// The types of gossip messages Serf will send along /// memberlist. 
-#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, derive_more::Display, derive_more::IsVariant)] #[repr(u8)] #[non_exhaustive] pub enum MessageType { /// Leave message + #[display("leave")] Leave, /// Join message + #[display("join")] Join, /// PushPull message + #[display("push_pull")] PushPull, /// UserEvent message + #[display("user_event")] UserEvent, /// Query message + #[display("query")] Query, /// QueryResponse message + #[display("query_response")] QueryResponse, /// ConflictResponse message + #[display("conflict_response")] ConflictResponse, /// Relay message + #[display("relay")] Relay, /// KeyRequest message #[cfg(feature = "encryption")] + #[display("key_request")] KeyRequest, /// KeyResponse message #[cfg(feature = "encryption")] + #[display("key_response")] KeyResponse, /// Unknown message type, used for forwards and backwards compatibility + #[display("unknown({_0})")] Unknown(u8), } @@ -141,19 +161,29 @@ macro_rules! bail { }; } +const RELAY_NODE_TAG: u8 = 1; +const RELAY_MSG_TAG: u8 = 2; + +const RELAY_NODE_BYTE: u8 = merge(WireType::LengthDelimited, RELAY_NODE_TAG); +const RELAY_MSG_BYTE: u8 = merge(WireType::LengthDelimited, RELAY_MSG_TAG); + /// A trait for encoding messages. pub trait Encodable { /// Encodes the message into a buffer. fn encode(&self, buf: &mut [u8]) -> Result; /// Encodes a relay message into a buffer. 
- fn encode_relay(&self, buf: &mut [u8]) -> Result { + fn encode_relay(&self, node: &Node, buf: &mut [u8]) -> Result + where + I: Data, + A: Data, + { let mut offset = 0; let buf_len = buf.len(); if offset >= buf_len { return Err(EncodeError::insufficient_buffer( - self.encoded_len_with_relay(), + self.encoded_len_with_relay(node), buf_len, )); } @@ -161,12 +191,36 @@ pub trait Encodable { buf[offset] = RELAY_MESSAGE_BYTE; offset += 1; + if offset >= buf_len { + return Err(EncodeError::insufficient_buffer( + self.encoded_len_with_relay(node), + buf_len, + )); + } + + buf[offset] = RELAY_NODE_BYTE; + offset += 1; + + offset += node + .encode_length_delimited(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len_with_relay(node), buf_len))?; + + if offset >= buf_len { + return Err(EncodeError::insufficient_buffer( + self.encoded_len_with_relay(node), + buf_len, + )); + } + + buf[offset] = RELAY_MSG_BYTE; + offset += 1; + offset += self .encode(&mut buf[offset..]) - .map_err(|e| e.update(self.encoded_len_with_relay(), buf_len))?; + .map_err(|e| e.update(self.encoded_len_with_relay(node), buf_len))?; #[cfg(debug_assertions)] - super::debug_assert_write_eq(offset, self.encoded_len_with_relay()); + super::debug_assert_write_eq(offset, self.encoded_len_with_relay(node)); Ok(offset) } @@ -179,18 +233,26 @@ pub trait Encodable { } /// Encodes a relay message into a [`Bytes`]. - fn encode_relay_to_bytes(&self) -> Result { - let len = self.encoded_len_with_relay(); + fn encode_relay_to_bytes(&self, node: &Node) -> Result + where + I: Data, + A: Data, + { + let len = self.encoded_len_with_relay(node); let mut buf = vec![0; len]; - self.encode_relay(&mut buf).map(|_| Bytes::from(buf)) + self.encode_relay(node, &mut buf).map(|_| Bytes::from(buf)) } /// Returns the encoded length of the message. fn encoded_len(&self) -> usize; /// Returns the encoded length of the message with a relay tag. 
- fn encoded_len_with_relay(&self) -> usize { - 1 + self.encoded_len() + fn encoded_len_with_relay(&self, node: &Node) -> usize + where + I: Data, + A: Data, + { + 1 + node.encoded_len_with_length_delimited() + 1 + self.encoded_len() } } @@ -307,3 +369,319 @@ where 1 + self.encoded_len_in() } } + +/// A reference to a message. +pub enum MessageRef<'a, I, A> { + /// Leave message + Leave(LeaveMessage), + /// Join message + Join(JoinMessage), + /// PushPull message + PushPull(PushPullMessageRef<'a, I>), + /// UserEvent message + UserEvent(UserEventMessageRef<'a>), + /// Query message + Query(QueryMessageRef<'a, I, A>), + /// QueryResponse message + QueryResponse(QueryResponseMessageRef<'a, I, A>), + /// ConflictResponse message + ConflictResponse(ConflictResponseMessageRef<'a, I, A>), + /// Relay message + Relay { + /// The node + node: Node, + /// The offset of the payload to the original buffer + payload_offset: usize, + /// The relay message payload + payload: &'a [u8], + }, + #[cfg(feature = "encryption")] + /// KeyRequest message + KeyRequest(KeyRequestMessage), + #[cfg(feature = "encryption")] + /// KeyResponse message + KeyResponse(KeyResponseMessageRef<'a>), +} + +impl MessageRef<'_, I, A> { + /// Returns the message type. + #[inline] + pub fn ty(&self) -> MessageType { + match self { + Self::Leave(_) => MessageType::Leave, + Self::Join(_) => MessageType::Join, + Self::PushPull(_) => MessageType::PushPull, + Self::UserEvent(_) => MessageType::UserEvent, + Self::Query(_) => MessageType::Query, + Self::QueryResponse(_) => MessageType::QueryResponse, + Self::ConflictResponse(_) => MessageType::ConflictResponse, + Self::Relay { .. } => MessageType::Relay, + #[cfg(feature = "encryption")] + Self::KeyRequest(_) => MessageType::KeyRequest, + #[cfg(feature = "encryption")] + Self::KeyResponse(_) => MessageType::KeyResponse, + } + } +} + +/// Decode a message from a buffer. 
+pub fn decode_message( + buf: &[u8], +) -> Result, A::Ref<'_>>, DecodeError> +where + I: Data + Eq + core::hash::Hash, + A: Data, +{ + let mut offset = 0; + let buf_len = buf.len(); + let mut msg = None; + + while offset < buf_len { + match buf[offset] { + LEAVE_MESSAGE_BYTE => { + if msg.is_some() { + return Err(DecodeError::duplicate_field( + "Message", + "value", + LEAVE_MESSAGE_TAG, + )); + } + offset += 1; + + let (len, val) = + > as DataRef<'_, LeaveMessage>>::decode_length_delimited( + &buf[offset..], + )?; + offset += len; + msg = Some(MessageRef::Leave(val)); + } + JOIN_MESSAGE_BYTE => { + if msg.is_some() { + return Err(DecodeError::duplicate_field( + "Message", + "value", + JOIN_MESSAGE_TAG, + )); + } + + offset += 1; + let (len, val) = + > as DataRef<'_, JoinMessage>>::decode_length_delimited( + &buf[offset..], + )?; + offset += len; + msg = Some(MessageRef::Join(val)); + } + PUSH_PULL_MESSAGE_BYTE => { + if msg.is_some() { + return Err(DecodeError::duplicate_field( + "Message", + "value", + PUSH_PULL_MESSAGE_TAG, + )); + } + + offset += 1; + let (len, val) = > as DataRef<'_, PushPullMessage>>::decode_length_delimited(&buf[offset..])?; + offset += len; + msg = Some(MessageRef::PushPull(val)); + } + USER_EVENT_MESSAGE_BYTE => { + if msg.is_some() { + return Err(DecodeError::duplicate_field( + "Message", + "value", + USER_EVENT_MESSAGE_TAG, + )); + } + + offset += 1; + let (len, val) = + as DataRef<'_, UserEventMessage>>::decode_length_delimited( + &buf[offset..], + )?; + offset += len; + msg = Some(MessageRef::UserEvent(val)); + } + QUERY_MESSAGE_BYTE => { + if msg.is_some() { + return Err(DecodeError::duplicate_field( + "Message", + "value", + QUERY_MESSAGE_TAG, + )); + } + offset += 1; + let (len, val) = , A::Ref<'_>> as DataRef< + '_, + QueryMessage, + >>::decode_length_delimited(&buf[offset..])?; + offset += len; + msg = Some(MessageRef::Query(val)); + } + QUERY_RESPONSE_MESSAGE_BYTE => { + if msg.is_some() { + return 
Err(DecodeError::duplicate_field( + "Message", + "value", + QUERY_RESPONSE_MESSAGE_TAG, + )); + } + offset += 1; + let (len, val) = , A::Ref<'_>> as DataRef< + '_, + QueryResponseMessage, + >>::decode_length_delimited(&buf[offset..])?; + offset += len; + msg = Some(MessageRef::QueryResponse(val)); + } + CONFLICT_RESPONSE_MESSAGE_BYTE => { + if msg.is_some() { + return Err(DecodeError::duplicate_field( + "Message", + "value", + CONFLICT_RESPONSE_MESSAGE_TAG, + )); + } + offset += 1; + let (len, val) = , A::Ref<'_>> as DataRef< + '_, + ConflictResponseMessage, + >>::decode_length_delimited(&buf[offset..])?; + offset += len; + msg = Some(MessageRef::ConflictResponse(val)); + } + RELAY_MESSAGE_BYTE => { + if msg.is_some() { + return Err(DecodeError::duplicate_field( + "Message", + "value", + RELAY_MESSAGE_TAG, + )); + } + offset += 1; + let (readed, (node, payload)) = decode_relay::(&buf[offset..])?; + offset += readed; + msg = Some(MessageRef::Relay { + node, + payload, + payload_offset: offset - payload.len(), + }); + } + #[cfg(feature = "encryption")] + KEY_REQUEST_MESSAGE_BYTE => { + if msg.is_some() { + return Err(DecodeError::duplicate_field( + "Message", + "value", + KEY_REQUEST_MESSAGE_TAG, + )); + } + + offset += 1; + let (len, val) = + >::decode_length_delimited( + &buf[offset..], + )?; + offset += len; + msg = Some(MessageRef::KeyRequest(val)); + } + #[cfg(feature = "encryption")] + KEY_RESPONSE_MESSAGE_BYTE => { + if msg.is_some() { + return Err(DecodeError::duplicate_field( + "Message", + "value", + KEY_RESPONSE_MESSAGE_TAG, + )); + } + + offset += 1; + let (len, val) = + as DataRef<'_, KeyResponseMessage>>::decode_length_delimited( + &buf[offset..], + )?; + offset += len; + msg = Some(MessageRef::KeyResponse(val)); + } + other => { + offset += 1; + + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &buf[offset..])?; + } + } + } + + let msg = 
msg.ok_or(DecodeError::missing_field("Message", "value"))?; + Ok(msg) +} + +fn decode_relay( + buf: &[u8], +) -> Result<(usize, (Node, A::Ref<'_>>, &[u8])), DecodeError> +where + I: Data, + A: Data, +{ + let mut offset = 0; + let buf_len = buf.len(); + + let mut node = None; + let mut msg = None; + + while offset < buf_len { + match buf[offset] { + RELAY_NODE_BYTE => { + if node.is_some() { + return Err(DecodeError::duplicate_field( + "RelayMessage", + "node", + RELAY_NODE_TAG, + )); + } + offset += 1; + + let (len, val) = + , A::Ref<'_>> as DataRef<'_, Node>>::decode_length_delimited( + &buf[offset..], + )?; + offset += len; + node = Some(val); + } + RELAY_MSG_BYTE => { + if msg.is_some() { + return Err(DecodeError::duplicate_field( + "RelayMessage", + "msg", + RELAY_MSG_TAG, + )); + } + offset += 1; + + // Skip length-delimited field by reading the length and skipping the payload + if buf[offset..].is_empty() { + return Err(DecodeError::buffer_underflow()); + } + + let (read, length) = ::decode(&buf[offset..])?; + offset += read; + + msg = Some(&buf[offset..offset + length as usize]); + offset += length as usize; + } + other => { + offset += 1; + + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &buf[offset..])?; + } + } + } + + let node = node.ok_or(DecodeError::missing_field("RelayMessage", "node"))?; + + Ok((offset, (node, msg.unwrap_or_default()))) +} diff --git a/serf-proto/src/query.rs b/serf-proto/src/query.rs index c4a2d6d..8983943 100644 --- a/serf-proto/src/query.rs +++ b/serf-proto/src/query.rs @@ -168,13 +168,27 @@ pub struct QueryMessageRef<'a, I, A> { ))] timeout: Duration, /// Query nqme - #[viewit(getter(const, style = "ref", attrs(doc = "Returns the name of the query")))] + #[viewit(getter(const, style = "move", attrs(doc = "Returns the name of the query")))] name: &'a str, /// Query payload - #[viewit(getter(const, style = "ref", attrs(doc = 
"Returns the payload")))] + #[viewit(getter(const, style = "move", attrs(doc = "Returns the payload")))] payload: &'a [u8], } +impl QueryMessageRef<'_, I, A> { + /// Checks if the ack flag is set + #[inline] + pub fn ack(&self) -> bool { + self.flags.contains(QueryFlag::ACK) + } + + /// Checks if the no broadcast flag is set + #[inline] + pub fn no_broadcast(&self) -> bool { + self.flags.contains(QueryFlag::NO_BROADCAST) + } +} + impl<'a, I, A> DataRef<'a, QueryMessage> for QueryMessageRef<'a, I::Ref<'a>, A::Ref<'a>> where I: Data, diff --git a/serf-proto/src/query/response.rs b/serf-proto/src/query/response.rs index e7e5c1d..48bfbbf 100644 --- a/serf-proto/src/query/response.rs +++ b/serf-proto/src/query/response.rs @@ -81,19 +81,23 @@ impl QueryResponseMessage { #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct QueryResponseMessageRef<'a, I, A> { /// Event lamport time - #[viewit(getter(const, attrs(doc = "Returns the lamport time for this message")))] + #[viewit(getter( + const, + style = "move", + attrs(doc = "Returns the lamport time for this message") + ))] ltime: LamportTime, /// query id - #[viewit(getter(const, attrs(doc = "Returns the query id")))] + #[viewit(getter(const, style = "move", attrs(doc = "Returns the query id")))] id: u32, /// node #[viewit(getter(const, attrs(doc = "Returns the from node")))] from: Node, /// Used to provide various flags - #[viewit(getter(const, style = "ref", attrs(doc = "Returns the flags")))] + #[viewit(getter(const, style = "move", attrs(doc = "Returns the flags")))] flags: QueryFlag, /// Optional response payload - #[viewit(getter(const, style = "ref", attrs(doc = "Returns the payload")))] + #[viewit(getter(const, style = "move", attrs(doc = "Returns the payload")))] payload: &'a [u8], } diff --git a/serf-proto/src/tags.rs b/serf-proto/src/tags.rs index 607c492..7d52fb9 100644 --- a/serf-proto/src/tags.rs +++ b/serf-proto/src/tags.rs @@ -168,213 +168,3 @@ impl Data for Tags { .map_err(|e: EncodeError| 
e.update(self.encoded_len(), buf.len())) } } - -// #[derive(Debug)] -// struct Tag { -// key: SmolStr, -// value: SmolStr, -// } - -// impl Tag { -// fn split(self) -> (SmolStr, SmolStr) { -// (self.key, self.value) -// } -// } - -// impl Data for Tag { -// type Ref<'a> = TagRef<'a>; - -// fn from_ref(val: Self::Ref<'_>) -> Result -// where -// Self: Sized, -// { -// Ok(Self { -// key: SmolStr::new(val.key), -// value: SmolStr::new(val.value), -// }) -// } - -// fn encoded_len(&self) -> usize { -// TagRef::new(&self.key, &self.value).encoded_len() -// } - -// fn encode(&self, buf: &mut [u8]) -> Result { -// TagRef::new(&self.key, &self.value).encode(buf) -// } -// } - -// /// A reference to a (key, value) pair of a tag -// #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -// pub struct TagRef<'a> { -// key: &'a str, -// value: &'a str, -// } - -// impl<'a> DataRef<'a, Tag> for TagRef<'a> { -// fn decode(src: &'a [u8]) -> Result<(usize, Self), DecodeError> -// where -// Self: Sized, -// { -// let mut offset = 0; -// let buf_len = src.len(); - -// let mut key = None; -// let mut val = None; - -// while offset < buf_len { -// match src[offset] { -// Self::KEY_BYTE => { -// if key.is_some() { -// return Err(DecodeError::duplicate_field("Tag", "key", Self::KEY_TAG)); -// } -// offset += 1; - -// let (read, value) = -// <&str as DataRef<'_, SmolStr>>::decode_length_delimited(&src[offset..])?; -// key = Some(value); -// offset += read; -// } -// Self::VALUE_BYTE => { -// if val.is_some() { -// return Err(DecodeError::duplicate_field( -// "Tag", -// "value", -// Self::VALUE_TAG, -// )); -// } -// offset += 1; - -// let (read, value) = -// <&str as DataRef<'_, SmolStr>>::decode_length_delimited(&src[offset..])?; -// val = Some(value); -// offset += read; -// } -// other => { -// offset += 1; - -// let (wire_type, _) = split(other); -// let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; -// offset += skip(wire_type, &src[offset..])?; 
-// } -// } -// } - -// Ok(( -// offset, -// Self { -// key: key.unwrap_or(""), -// value: val.unwrap_or(""), -// }, -// )) -// } -// } - -// impl<'a> TagRef<'a> { -// const KEY_TAG: u8 = 1; -// const KEY_BYTE: u8 = merge(WireType::LengthDelimited, Self::KEY_TAG); -// const VALUE_TAG: u8 = 2; -// const VALUE_BYTE: u8 = merge(WireType::LengthDelimited, Self::VALUE_TAG); - -// fn new(key: &'a str, value: &'a str) -> Self { -// Self { key, value } -// } - -// fn encoded_len(&self) -> usize { -// let klen = self.key.len(); -// let vlen = self.value.len(); - -// let mut len = 0; -// if klen != 0 { -// len += 1 + (klen as u32).encoded_len(); -// } - -// if vlen != 0 { -// len += 1 + (vlen as u32).encoded_len(); -// } - -// len -// } - -// fn encoded_len_with_length_delimited(&self) -> usize { -// let len = self.encoded_len(); -// len + (len as u32).encoded_len() -// } - -// fn encode(&self, buf: &mut [u8]) -> Result { -// let buf_len = buf.len(); -// let mut offset = 0; - -// if buf_len <= offset { -// return Err(EncodeError::insufficient_buffer( -// self.encoded_len(), -// buf_len, -// )); -// } - -// let klen = self.key.len(); -// if klen != 0 { -// buf[offset] = Self::KEY_BYTE; -// offset += 1; - -// offset += (klen as u32) -// .encode(&mut buf[offset..]) -// .map_err(|e| e.update(self.encoded_len(), buf_len))?; -// if buf_len < offset + klen { -// return Err(EncodeError::insufficient_buffer( -// self.encoded_len(), -// buf_len, -// )); -// } -// buf[offset..offset + klen].copy_from_slice(self.key.as_bytes()); -// offset += klen; -// } - -// if buf_len <= offset { -// return Err(EncodeError::insufficient_buffer( -// self.encoded_len(), -// buf_len, -// )); -// } - -// let vlen = self.value.len(); -// if vlen != 0 { -// buf[offset] = Self::VALUE_BYTE; -// offset += 1; - -// offset += (vlen as u32) -// .encode(&mut buf[offset..]) -// .map_err(|e| e.update(self.encoded_len(), buf_len))?; -// if buf_len < offset + vlen { -// return Err(EncodeError::insufficient_buffer( 
-// self.encoded_len(), -// buf_len, -// )); -// } - -// buf[offset..offset + vlen].copy_from_slice(self.value.as_bytes()); -// offset += vlen; -// } - -// #[cfg(debug_assertions)] -// super::debug_assert_write_eq(offset, self.encoded_len()); - -// Ok(offset) -// } - -// fn encode_with_length_delimited(&self, buf: &mut [u8]) -> Result { -// let len = self.encoded_len(); -// let buf_len = buf.len(); -// if buf_len < len { -// return Err(EncodeError::insufficient_buffer(len, buf_len)); -// } - -// let mut offset = 0; -// offset += (len as u32).encode(&mut buf[offset..])?; -// offset += self.encode(&mut buf[offset..])?; - -// #[cfg(debug_assertions)] -// super::debug_assert_write_eq(offset, self.encoded_len_with_length_delimited()); - -// Ok(offset) -// } -// } diff --git a/serf-proto/src/user_event/message.rs b/serf-proto/src/user_event/message.rs index 7b192b5..182cec5 100644 --- a/serf-proto/src/user_event/message.rs +++ b/serf-proto/src/user_event/message.rs @@ -80,16 +80,24 @@ const PAYLOAD_BYTE: u8 = merge(WireType::LengthDelimited, PAYLOAD_TAG); #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub struct UserEventMessageRef<'a> { /// The lamport time - #[viewit(getter(const, attrs(doc = "Returns the lamport time for this message")))] + #[viewit(getter( + const, + style = "move", + attrs(doc = "Returns the lamport time for this message") + ))] ltime: LamportTime, /// The name of the event - #[viewit(getter(const, attrs(doc = "Returns the name of the event")))] + #[viewit(getter(const, style = "move", attrs(doc = "Returns the name of the event")))] name: &'a str, /// The payload of the event - #[viewit(getter(const, attrs(doc = "Returns the payload of the event")))] + #[viewit(getter(const, style = "move", attrs(doc = "Returns the payload of the event")))] payload: &'a [u8], /// "Can Coalesce". 
- #[viewit(getter(const, attrs(doc = "Returns if this message can be coalesced")))] + #[viewit(getter( + const, + style = "move", + attrs(doc = "Returns if this message can be coalesced") + ))] cc: bool, } From 89f261d83e5135f6d45225f275c66ec475640069 Mon Sep 17 00:00:00 2001 From: al8n Date: Fri, 28 Feb 2025 20:59:57 +0800 Subject: [PATCH 08/39] WIP --- serf-core/src/serf/base.rs | 19 +-- serf-core/src/serf/base/tests.rs | 111 ++++++++---------- .../src/serf/base/tests/serf/delegate.rs | 13 +- serf-core/src/serf/base/tests/serf/event.rs | 23 ++-- serf-core/src/serf/base/tests/serf/join.rs | 2 +- .../src/serf/base/tests/serf/snapshot.rs | 2 +- serf-core/src/serf/delegate.rs | 8 +- serf-core/src/snapshot.rs | 4 +- serf-core/src/types/message.rs | 10 -- 9 files changed, 88 insertions(+), 104 deletions(-) delete mode 100644 serf-core/src/types/message.rs diff --git a/serf-core/src/serf/base.rs b/serf-core/src/serf/base.rs index b1870af..5c33734 100644 --- a/serf-core/src/serf/base.rs +++ b/serf-core/src/serf/base.rs @@ -12,7 +12,7 @@ use memberlist_core::{ transport::{MaybeResolvedAddress, Node}, }; use rand::{Rng, SeedableRng}; -use serf_proto::{MessageRef, QueryMessageRef, QueryResponseMessageRef, Tags, UserEventMessageRef}; +use serf_proto::{MessageRef, Tags, UserEventMessageRef}; use smol_str::SmolStr; use crate::{ @@ -33,17 +33,17 @@ use self::internal_query::SerfQueries; use super::*; -// /// Re-export the unit tests -// #[cfg(feature = "test")] -// #[cfg_attr(docsrs, doc(cfg(feature = "test")))] -// pub mod tests; +/// Re-export the unit tests +#[cfg(feature = "test")] +#[cfg_attr(docsrs, doc(cfg(feature = "test")))] +pub mod tests; impl Serf where D: Delegate, T: Transport, { - #[cfg(feature = "test")] + #[cfg(any(feature = "test", test))] pub(crate) async fn with_message_dropper( transport: T::Options, opts: Options, @@ -54,7 +54,7 @@ where None, transport, opts, - #[cfg(feature = "test")] + #[cfg(any(feature = "test", test))] Some(message_dropper), ) .await 
@@ -841,6 +841,7 @@ where true } + #[allow(clippy::too_many_arguments)] pub(crate) fn query_event( &self, ltime: LamportTime, @@ -933,7 +934,7 @@ where .await; // Process query locally - self.handle_query(Either::Right(q), ty).await; + self.handle_query(Either::Right(q), ty).await?; // Start broadcasting the event self @@ -1708,7 +1709,7 @@ where // Start an id resolution query let ty = InternalQueryEvent::Conflict(local_id.clone()); let resp = match self - .internal_query(SmolStr::new(ty.as_str()), payload.into(), None, ty) + .internal_query(SmolStr::new(ty.as_str()), payload, None, ty) .await { Ok(resp) => resp, diff --git a/serf-core/src/serf/base/tests.rs b/serf-core/src/serf/base/tests.rs index 024efe4..b1ea37c 100644 --- a/serf-core/src/serf/base/tests.rs +++ b/serf-core/src/serf/base/tests.rs @@ -6,23 +6,22 @@ use memberlist_core::{ bytes::Bytes, delegate::NodeDelegate, transport::MaybeResolvedAddress, - types::{OneOrMore, TinyVec}, + proto::{OneOrMore, TinyVec}, }; use serf_proto::{ - MessageType, Node, PushPullMessage, QueryFlag, QueryMessage, SerfMessage, UserEvent, - UserEventMessage, + MessageType, Node, PushPullMessage, QueryFlag, QueryMessage, UserEvent, + UserEventMessage, MessageRef, }; use smol_str::SmolStr; use crate::{ - delegate::, event::{CrateEvent, CrateEventType, MemberEvent, MemberEventType}, types::Epoch, }; use super::*; -pub(crate) mod serf; +// pub(crate) mod serf; fn test_config() -> Options { let mut opts = Options::new(); @@ -232,17 +231,15 @@ where event_tx.send(event.clone()).await.unwrap(); // Push a query - let query = s.query_event(QueryMessage { - ltime: 42.into(), - id: 1, - from: s.memberlist().advertise_node(), - filters: TinyVec::new(), - flags: QueryFlag::empty(), - relay_factor: 0, - timeout: Default::default(), - name: "foo".into(), - payload: Bytes::new(), - }); + let query = s.query_event( + 42.into(), + "foo".into(), + Bytes::new(), + Default::default(), + 1, + s.memberlist().advertise_node(), + 0, + ); 
event_tx.send(CrateEvent::from(query)).await.unwrap(); // Push a member event @@ -272,17 +269,15 @@ where let (event_tx, _handle) = SerfQueries::>::new(Some(tx), shutdown_rx); // Push a query - let query = s.query_event(QueryMessage { - ltime: 42.into(), - id: 1, - from: s.memberlist().advertise_node(), - filters: TinyVec::new(), - flags: QueryFlag::empty(), - relay_factor: 0, - timeout: Default::default(), - name: "ping".into(), - payload: Bytes::new(), - }); + let query = s.query_event( + 42.into(), + "ping".into(), + Bytes::new(), + Default::default(), + 1, + s.memberlist().advertise_node(), + 0, + ); event_tx .send(CrateEvent::from((InternalQueryEvent::Ping, query))) .await @@ -305,17 +300,15 @@ where let (event_tx, _handle) = SerfQueries::>::new(Some(tx), shutdown_rx); // Push a query - let query = s.query_event(QueryMessage { - ltime: 42.into(), - id: 1, - from: s.memberlist().advertise_node(), - filters: TinyVec::new(), - flags: QueryFlag::empty(), - relay_factor: 0, - timeout: Default::default(), - name: "conflict".into(), - payload: Bytes::new(), - }); + let query = s.query_event( + 42.into(), + "conflict".into(), + Bytes::new(), + Default::default(), + 1, + s.memberlist().advertise_node(), + 0, + ); let id = s.memberlist().local_id().clone(); event_tx .send(CrateEvent::from((InternalQueryEvent::Conflict(id), query))) @@ -345,17 +338,15 @@ pub async fn estimate_max_keys_in_list_key_response_factor( let size_limit = opts.query_response_size_limit() * 10; let opts = opts.with_query_response_size_limit(size_limit); let s = Serf::::new(transport_opts, opts).await.unwrap(); - let query = s.query_event(QueryMessage { - ltime: 0.into(), - id: 0, - from: s.memberlist().advertise_node(), - filters: TinyVec::new(), - flags: QueryFlag::empty(), - relay_factor: 0, - timeout: Default::default(), - name: Default::default(), - payload: Default::default(), - }); + let query = s.query_event( + 0.into(), + Default::default(), + Default::default(), + Default::default(), + 0, 
+ s.memberlist().advertise_node(), + 0, + ); let mut resp = KeyResponseMessage::default(); for _ in 0..=(size_limit / 25) { @@ -395,21 +386,19 @@ where T: Transport, { use memberlist_core::proto::SecretKey; - use serf_proto::{Encodable, KeyResponseMessage}; + use serf_proto::KeyResponseMessage; let opts = opts.with_query_response_size_limit(1024); let s = Serf::::new(transport_opts, opts).await.unwrap(); - let query = s.query_event(QueryMessage { - ltime: 0.into(), - id: 0, - from: s.memberlist().advertise_node(), - filters: TinyVec::new(), - flags: QueryFlag::empty(), - relay_factor: 0, - timeout: Default::default(), - name: Default::default(), - payload: Default::default(), - }); + let query = s.query_event( + 0.into(), + Default::default(), + Default::default(), + Default::default(), + 0, + s.memberlist().advertise_node(), + 0, + ); let k = [0; 16]; let encoded_len = SecretKey::from(k).encoded_len(); diff --git a/serf-core/src/serf/base/tests/serf/delegate.rs b/serf-core/src/serf/base/tests/serf/delegate.rs index aa50fee..1d3202a 100644 --- a/serf-core/src/serf/base/tests/serf/delegate.rs +++ b/serf-core/src/serf/base/tests/serf/delegate.rs @@ -14,7 +14,7 @@ where .unwrap(); let meta = s.inner.memberlist.delegate().unwrap().node_meta(32).await; - let (_, tags) = decode_tags(&meta).unwrap(); + let (_, tags) = Tags::decode(&meta).unwrap(); assert_eq!(tags.get("role"), Some(&SmolStr::new("test"))); s.shutdown().await.unwrap(); @@ -78,16 +78,17 @@ where .await; // Verify - assert_eq!(buf[0], MessageType::PushPull as u8, "bad message type"); + assert_eq!(buf[0], MessageType::PushPull.into(), "bad message type"); // Attempt a decode - let (_, pp) = - decode_message(MessageType::PushPull, &buf[1..]) + let pp = + serf_proto::decode_message(&buf) .unwrap(); - let SerfMessage::PushPull(pp) = pp else { + let MessageRef::PushPull(pp) = pp else { panic!("bad message") }; + let pp = PushPullMessage::from_ref(pp).unwrap(); // Verify lamport clock assert_eq!(pp.ltime(), 
serfs[0].inner.clock.time(), "bad lamport clock"); @@ -150,7 +151,7 @@ where let buf = serf_proto::Encodable::encode_to_bytes(&pp).unwrap(); // Merge in fake state - d.merge_remote_state(buf, false).await; + d.merge_remote_state(&buf, false).await; // Verify lamport assert_eq!(s.inner.clock.time(), 42.into(), "bad lamport clock"); diff --git a/serf-core/src/serf/base/tests/serf/event.rs b/serf-core/src/serf/base/tests/serf/event.rs index d4e6ce9..d343248 100644 --- a/serf-core/src/serf/base/tests/serf/event.rs +++ b/serf-core/src/serf/base/tests/serf/event.rs @@ -18,9 +18,10 @@ where assert!( !s1 .handle_user_event( - UserEventMessage::default() + Either::Right(UserEventMessage::default() .with_ltime(1.into()) .with_name("old".into()) + ) ) .await, "should not rebroadcast" @@ -43,19 +44,19 @@ where .with_ltime(1.into()) .with_name("first".into()) .with_payload(Bytes::from_static(b"test")); - assert!(s1.handle_user_event(msg).await, "should rebroadcast"); + assert!(s1.handle_user_event(Either::Right(msg)).await, "should rebroadcast"); let msg = UserEventMessage::default() .with_ltime(1.into()) .with_name("first".into()) .with_payload(Bytes::from_static(b"newpayload")); - assert!(s1.handle_user_event(msg).await, "should rebroadcast"); + assert!(s1.handle_user_event(Either::Right(msg)).await, "should rebroadcast"); let msg = UserEventMessage::default() .with_ltime(1.into()) .with_name("second".into()) .with_payload(Bytes::from_static(b"other")); - assert!(s1.handle_user_event(msg).await, "should rebroadcast"); + assert!(s1.handle_user_event(Either::Right(msg)).await, "should rebroadcast"); test_user_events( event_rx.rx, @@ -679,7 +680,7 @@ pub async fn query_old_message( }, None ) - .await, + .await.unwrap(), "should not rebroadcast" ); @@ -712,11 +713,11 @@ pub async fn query_same_clock( }; assert!( - s1.handle_query(msg.clone(), None).await, + s1.handle_query(Either::Right(msg.clone()), None).await, "should rebroadcast" ); assert!( - !s1.handle_query(msg.clone(), 
None).await, + !s1.handle_query(Either::Right(msg.clone()), None).await, "should not rebroadcast" ); @@ -733,11 +734,11 @@ pub async fn query_same_clock( }; assert!( - s1.handle_query(msg.clone(), None).await, + s1.handle_query(Either::Right(msg.clone()), None).await, "should rebroadcast" ); assert!( - !s1.handle_query(msg.clone(), None).await, + !s1.handle_query(Either::Right(msg.clone()), None).await, "should not rebroadcast" ); @@ -753,11 +754,11 @@ pub async fn query_same_clock( payload: Bytes::from_static(b"other"), }; assert!( - s1.handle_query(msg.clone(), None).await, + s1.handle_query(Either::Right(msg.clone()), None).await, "should rebroadcast" ); assert!( - !s1.handle_query(msg.clone(), None).await, + !s1.handle_query(Either::Right(msg.clone()), None).await, "should not rebroadcast" ); diff --git a/serf-core/src/serf/base/tests/serf/join.rs b/serf-core/src/serf/base/tests/serf/join.rs index 1aca9ae..4dee47e 100644 --- a/serf-core/src/serf/base/tests/serf/join.rs +++ b/serf-core/src/serf/base/tests/serf/join.rs @@ -562,7 +562,7 @@ impl MergeDelegate for CancelMergeDelegat async fn notify_merge( &self, - _members: TinyVec>, + _members: Arc<[Member]>, ) -> Result<(), Self::Error> { self.invoked.store(true, Ordering::SeqCst); Err(CancelMergeError) diff --git a/serf-core/src/serf/base/tests/serf/snapshot.rs b/serf-core/src/serf/base/tests/serf/snapshot.rs index 9547d9d..d29318f 100644 --- a/serf-core/src/serf/base/tests/serf/snapshot.rs +++ b/serf-core/src/serf/base/tests/serf/snapshot.rs @@ -615,7 +615,7 @@ pub async fn serf_snapshot_recovery( async fn test_snapshoter_slow_disk_not_blocking_event_tx() { use memberlist_core::{ agnostic_lite::tokio::TokioRuntime, - transport::{Lpe, resolver::socket_addr::SocketAddrResolver, tests::UnimplementedTransport}, + transport::{resolver::socket_addr::SocketAddrResolver, tests::UnimplementedTransport}, }; use std::net::SocketAddr; diff --git a/serf-core/src/serf/delegate.rs b/serf-core/src/serf/delegate.rs index 
bb3d8b4..1690fd2 100644 --- a/serf-core/src/serf/delegate.rs +++ b/serf-core/src/serf/delegate.rs @@ -2,11 +2,11 @@ use crate::{ Serf, broadcast::SerfBroadcast, delegate::Delegate, - error::{Error, SerfDelegateError, SerfError}, + error::{SerfDelegateError, SerfError}, event::QueryMessageExt, types::{ DelegateVersion, JoinMessage, LamportTime, LeaveMessage, Member, MemberStatus, - MemberlistDelegateVersion, MemberlistProtocolVersion, MessageRef, MessageType, ProtocolVersion, + MessageRef, MessageType, ProtocolVersion, PushPullMessageBorrow, UserEventMessage, }, }; @@ -26,9 +26,9 @@ use memberlist_core::{ AliveDelegate, ConflictDelegate, Delegate as MemberlistDelegate, EventDelegate, MergeDelegate as MemberlistMergeDelegate, NodeDelegate, PingDelegate, }, - proto::{Data, Meta, NodeState, SmallVec, State, TinyVec}, + proto::{Data, Meta, NodeState, State}, tracing, - transport::{AddressResolver, Transport}, + transport::Transport, }; use serf_proto::{PushPullMessage, Tags}; diff --git a/serf-core/src/snapshot.rs b/serf-core/src/snapshot.rs index 5ad9e71..ee512cd 100644 --- a/serf-core/src/snapshot.rs +++ b/serf-core/src/snapshot.rs @@ -649,7 +649,9 @@ where } self.wait_tx.close(); - tee_handle.await; + if let Err(e) = tee_handle.await { + tracing::error!(target="serf", err=%e, "failed to wait for tee stream to exit"); + } tracing::debug!("serf: snapshotter stream exits"); } diff --git a/serf-core/src/types/message.rs b/serf-core/src/types/message.rs deleted file mode 100644 index 1711255..0000000 --- a/serf-core/src/types/message.rs +++ /dev/null @@ -1,10 +0,0 @@ -use memberlist_core::transport::Node; - -/// Used to store the end destination of a relayed message -#[derive(Debug, Clone, Eq, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[cfg_attr(feature = "serde", serde(transparent))] -#[repr(transparent)] -pub(crate) struct RelayHeader { - pub(crate) dest: Node, -} From 24a0afea2a1efb0ab4661434c5dec169d6b5083c Mon Sep 17 
00:00:00 2001 From: al8n Date: Sat, 1 Mar 2025 00:07:18 +0800 Subject: [PATCH 09/39] WIP --- Cargo.toml | 2 - serf-core/Cargo.toml | 15 +- serf-core/src/coalesce/member.rs | 2 +- serf-core/src/coalesce/user.rs | 2 +- serf-core/src/coordinate.rs | 166 +++++- serf-core/src/event.rs | 4 +- serf-core/src/event/crate_event.rs | 2 +- serf-core/src/key_manager.rs | 44 +- serf-core/src/serf/api.rs | 13 +- serf-core/src/serf/base.rs | 19 +- serf-core/src/serf/base/tests.rs | 19 +- serf-core/src/serf/base/tests/serf.rs | 2 +- .../src/serf/base/tests/serf/delegate.rs | 4 +- serf-core/src/serf/base/tests/serf/event.rs | 2 +- serf-core/src/serf/base/tests/serf/join.rs | 32 +- serf-core/src/serf/base/tests/serf/leave.rs | 16 +- serf-core/src/serf/delegate.rs | 13 +- serf-core/src/serf/internal_query.rs | 20 +- serf-core/src/serf/query.rs | 6 +- serf-core/src/snapshot.rs | 8 +- serf-core/src/types.rs | 60 ++- .../src/types}/arbitrary_impl.rs | 6 +- .../src => serf-core/src/types}/clock.rs | 8 +- .../src => serf-core/src/types}/conflict.rs | 2 +- .../src => serf-core/src/types}/filter.rs | 2 +- .../src/types}/filter/tag_filter.rs | 4 +- .../src => serf-core/src/types}/join.rs | 4 +- .../src => serf-core/src/types}/key.rs | 2 +- .../src => serf-core/src/types}/leave.rs | 2 +- serf-core/src/types/member.rs | 464 ++++++++++++++++- .../src => serf-core/src/types}/message.rs | 13 +- .../src => serf-core/src/types}/push_pull.rs | 8 +- .../src => serf-core/src/types}/query.rs | 6 +- .../src/types}/query/response.rs | 8 +- .../src => serf-core/src/types}/tags.rs | 8 +- .../src => serf-core/src/types}/user_event.rs | 4 +- .../src/types}/user_event/message.rs | 6 +- .../src/types}/user_event/user_events.rs | 8 +- .../src => serf-core/src/types}/version.rs | 2 +- serf-proto/Cargo.toml | 43 -- serf-proto/src/lib.rs | 75 --- serf-proto/src/member.rs | 474 ------------------ 42 files changed, 832 insertions(+), 768 deletions(-) rename {serf-proto/src => serf-core/src/types}/arbitrary_impl.rs 
(98%) rename {serf-proto/src => serf-core/src/types}/clock.rs (93%) rename {serf-proto/src => serf-core/src/types}/conflict.rs (99%) rename {serf-proto/src => serf-core/src/types}/filter.rs (99%) rename {serf-proto/src => serf-core/src/types}/filter/tag_filter.rs (98%) rename {serf-proto/src => serf-core/src/types}/join.rs (97%) rename {serf-proto/src => serf-core/src/types}/key.rs (99%) rename {serf-proto/src => serf-core/src/types}/leave.rs (99%) rename {serf-proto/src => serf-core/src/types}/message.rs (98%) rename {serf-proto/src => serf-core/src/types}/push_pull.rs (97%) rename {serf-proto/src => serf-core/src/types}/query.rs (98%) rename {serf-proto/src => serf-core/src/types}/query/response.rs (98%) rename {serf-proto/src => serf-core/src/types}/tags.rs (94%) rename {serf-proto/src => serf-core/src/types}/user_event.rs (96%) rename {serf-proto/src => serf-core/src/types}/user_event/message.rs (97%) rename {serf-proto/src => serf-core/src/types}/user_event/user_events.rs (95%) rename {serf-proto/src => serf-core/src/types}/version.rs (97%) delete mode 100644 serf-proto/Cargo.toml delete mode 100644 serf-proto/src/lib.rs delete mode 100644 serf-proto/src/member.rs diff --git a/Cargo.toml b/Cargo.toml index 37e08ac..f2eb25a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,6 @@ members = [ "serf", "serf-core", - "serf-proto" ] resolver = "3" @@ -40,4 +39,3 @@ memberlist-core = { version = "0.6", path = "../memberlist/memberlist-core", def memberlist = { version = "0.6", path = "../memberlist/memberlist", default-features = false } serf-core = { path = "serf-core", version = "0.3.0", default-features = false } -serf-proto = { path = "serf-proto", version = "0.1.0", default-features = false } diff --git a/serf-core/Cargo.toml b/serf-core/Cargo.toml index f6f668c..cc6dbf8 100644 --- a/serf-core/Cargo.toml +++ b/serf-core/Cargo.toml @@ -13,14 +13,14 @@ categories.workspace = true [features] default = ["metrics"] -metrics = ["memberlist-core/metrics", 
"dep:metrics", "serf-proto/metrics"] -encryption = ["memberlist-core/encryption", "serf-proto/encryption", "base64", "serde"] +metrics = ["memberlist-core/metrics", "dep:metrics"] +encryption = ["memberlist-core/encryption", "base64", "serde"] serde = [ "dep:serde", "dep:humantime-serde", + "bitflags/serde", "memberlist-core/serde", - "serf-proto/serde", "smol_str/serde", "smallvec/serde", "indexmap/serde", @@ -28,12 +28,17 @@ serde = [ test = ["memberlist-core/test", "paste", "tracing-subscriber", "tempfile"] +arbitrary = ["dep:arbitrary", "memberlist-core/arbitrary", "smol_str/arbitrary"] +quickcheck = ["dep:quickcheck", "memberlist-core/quickcheck"] + + [dependencies] auto_impl = "1" atomic_refcell = "0.1" arc-swap = "1" async-lock = "3" async-channel = "2" +bitflags = "2" byteorder.workspace = true crossbeam-queue = "0.3" derive_more.workspace = true @@ -51,7 +56,6 @@ smallvec.workspace = true thiserror.workspace = true viewit.workspace = true memberlist-core.workspace = true -serf-proto.workspace = true metrics = { version = "0.24", optional = true } @@ -61,6 +65,9 @@ serde_json = "1" base64 = { version = "0.22", optional = true } +arbitrary = { version = "1", optional = true, default-features = false, features = ["derive"] } +quickcheck = { version = "1", optional = true, default-features = false } + # test features paste = { version = "1", optional = true } tracing-subscriber = { version = "0.3", optional = true, features = [ diff --git a/serf-core/src/coalesce/member.rs b/serf-core/src/coalesce/member.rs index 6fe9446..db325b5 100644 --- a/serf-core/src/coalesce/member.rs +++ b/serf-core/src/coalesce/member.rs @@ -122,7 +122,7 @@ where // agnostic_lite::{RuntimeLite, tokio::TokioRuntime}, // transport::resolver::socket_addr::SocketAddrResolver, // }; -// use serf_proto::{MemberStatus, UserEventMessage}; +// use crate::types::{MemberStatus, UserEventMessage}; // use smol_str::SmolStr; // use crate::{ diff --git a/serf-core/src/coalesce/user.rs 
b/serf-core/src/coalesce/user.rs index 7a31dc8..712d505 100644 --- a/serf-core/src/coalesce/user.rs +++ b/serf-core/src/coalesce/user.rs @@ -1,8 +1,8 @@ use std::marker::PhantomData; +use crate::types::UserEventMessage; use indexmap::IndexMap; use memberlist_core::proto::TinyVec; -use serf_proto::UserEventMessage; use smol_str::SmolStr; use crate::types::LamportTime; diff --git a/serf-core/src/coordinate.rs b/serf-core/src/coordinate.rs index 165bee8..7a285df 100644 --- a/serf-core/src/coordinate.rs +++ b/serf-core/src/coordinate.rs @@ -6,7 +6,10 @@ use std::{ use memberlist_core::{ CheapClone, - proto::{Data, DataRef, DecodeError, EncodeError, RepeatedDecoder}, + proto::{ + Data, DataRef, DecodeError, EncodeError, RepeatedDecoder, WireType, + utils::{merge, skip, split}, + }, }; use parking_lot::RwLock; use rand::Rng; @@ -726,6 +729,15 @@ fn rand_f64() -> f64 { } } +const PORTION_TAG: u8 = 1; +const ERROR_TAG: u8 = 2; +const ADJUSTMENT_TAG: u8 = 3; +const HEIGHT_TAG: u8 = 4; +const PORTION_BYTE: u8 = merge(WireType::LengthDelimited, PORTION_TAG); +const ERROR_BYTE: u8 = merge(WireType::Fixed64, ERROR_TAG); +const ADJUSTMENT_BYTE: u8 = merge(WireType::Fixed64, ADJUSTMENT_TAG); +const HEIGHT_BYTE: u8 = merge(WireType::Fixed64, HEIGHT_TAG); + /// The reference type to [`Coordinate`]. 
#[derive(Copy, Clone, Debug, PartialEq)] pub struct CoordinateRef<'a> { @@ -740,7 +752,98 @@ impl<'a> DataRef<'a, Coordinate> for CoordinateRef<'a> { where Self: Sized, { - todo!() + let mut offset = 0; + let buf_len = buf.len(); + + let mut portion_offsets = None; + let mut num_portions = 0; + let mut error = None; + let mut adjustment = None; + let mut height = None; + + while offset < buf_len { + match buf[offset] { + PORTION_TAG => { + let readed = skip(WireType::Fixed64, &buf[offset..])?; + if let Some((ref mut fnso, ref mut lnso)) = portion_offsets { + if *fnso > offset { + *fnso = offset - 1; + } + + if *lnso < offset + readed { + *lnso = offset + readed; + } + } else { + portion_offsets = Some((offset - 1, offset + readed)); + } + num_portions += 1; + offset += readed; + } + ERROR_TAG => { + if error.is_some() { + return Err(DecodeError::duplicate_field( + "Coordinate", + "error", + ERROR_TAG, + )); + } + + let (len, val) = ::decode(&buf[offset..])?; + offset += len; + error = Some(val); + } + ADJUSTMENT_TAG => { + if adjustment.is_some() { + return Err(DecodeError::duplicate_field( + "Coordinate", + "adjustment", + ADJUSTMENT_TAG, + )); + } + + let (len, val) = ::decode(&buf[offset..])?; + offset += len; + adjustment = Some(val); + } + HEIGHT_TAG => { + if height.is_some() { + return Err(DecodeError::duplicate_field( + "Coordinate", + "height", + HEIGHT_TAG, + )); + } + + let (len, val) = ::decode(&buf[offset..])?; + offset += len; + height = Some(val); + } + other => { + offset += 1; + + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &buf[offset..])?; + } + } + } + + Ok(( + offset, + Self { + portion: if let Some((start, end)) = portion_offsets { + RepeatedDecoder::new(PORTION_TAG, WireType::Fixed64, buf) + .with_nums(num_portions) + .with_offsets(start, end) + } else { + RepeatedDecoder::new(PORTION_TAG, WireType::Fixed64, buf) + }, + error: 
error.ok_or_else(|| DecodeError::missing_field("Coordinate", "error"))?, + adjustment: adjustment + .ok_or_else(|| DecodeError::missing_field("Coordinate", "adjustment"))?, + height: height.ok_or_else(|| DecodeError::missing_field("Coordinate", "height"))?, + }, + )) } } @@ -751,16 +854,67 @@ impl Data for Coordinate { where Self: Sized, { - // Ok(val) - todo!() + val + .portion + .iter::() + .collect::, _>>() + .map(|portion| Self { + portion, + error: val.error, + adjustment: val.adjustment, + height: val.height, + }) } fn encoded_len(&self) -> usize { - todo!() + self + .portion + .iter() + .fold(0, |acc, x| acc + 1 + x.encoded_len_with_length_delimited()) + + 1 + + self.error.encoded_len() + + 1 + + self.adjustment.encoded_len() + + 1 + + self.height.encoded_len() } fn encode(&self, buf: &mut [u8]) -> Result { - todo!() + macro_rules! bail { + ($this:ident($offset:expr, $len:ident)) => { + if $offset >= $len { + return Err(EncodeError::insufficient_buffer(self.encoded_len(), $len)); + } + }; + } + + let mut offset = 0; + let buf_len = buf.len(); + for x in self.portion.iter() { + bail!(self(offset, buf_len)); + buf[offset] = PORTION_BYTE; + offset += 1; + offset += x + .encode(&mut buf[offset..]) + .map_err(|e| e.update(self.encoded_len(), buf_len))?; + } + + bail!(self(offset, buf_len)); + buf[offset] = ERROR_BYTE; + offset += 1; + offset += self.error.encode(&mut buf[offset..])?; + + bail!(self(offset, buf_len)); + buf[offset] = ADJUSTMENT_BYTE; + offset += 1; + offset += self.adjustment.encode(&mut buf[offset..])?; + + bail!(self(offset, buf_len)); + buf[offset] = HEIGHT_BYTE; + offset += 1; + offset += self.height.encode(&mut buf[offset..])?; + + Ok(offset) } } diff --git a/serf-core/src/event.rs b/serf-core/src/event.rs index 962d250..78fc071 100644 --- a/serf-core/src/event.rs +++ b/serf-core/src/event.rs @@ -9,11 +9,11 @@ mod crate_event; use async_channel::Sender; pub use async_channel::{RecvError, TryRecvError}; +use crate::types::{LamportTime, Member, 
Node, QueryFlag, QueryResponseMessage, UserEventMessage}; use async_lock::Mutex; pub(crate) use crate_event::*; use futures::Stream; use memberlist_core::{CheapClone, bytes::Bytes, proto::TinyVec, transport::Transport}; -use serf_proto::{LamportTime, Member, Node, QueryFlag, QueryResponseMessage, UserEventMessage}; use smol_str::SmolStr; pub(crate) struct QueryContext @@ -91,7 +91,7 @@ where flags: QueryFlag::empty(), payload: msg, }; - let buf = serf_proto::Encodable::encode_to_bytes(&resp)?; + let buf = crate::types::Encodable::encode_to_bytes(&resp)?; self .respond_with_message_and_response(respond_to, relay_factor, buf, resp) .await diff --git a/serf-core/src/event/crate_event.rs b/serf-core/src/event/crate_event.rs index bf9444f..df45049 100644 --- a/serf-core/src/event/crate_event.rs +++ b/serf-core/src/event/crate_event.rs @@ -1,5 +1,5 @@ +use crate::types::{QueryMessage, QueryMessageRef}; use memberlist_core::proto::{Data, DecodeError}; -use serf_proto::{QueryMessage, QueryMessageRef}; use super::*; diff --git a/serf-core/src/key_manager.rs b/serf-core/src/key_manager.rs index 4c2135d..7bd4689 100644 --- a/serf-core/src/key_manager.rs +++ b/serf-core/src/key_manager.rs @@ -1,10 +1,10 @@ use std::{collections::HashMap, sync::OnceLock}; +use crate::types::MessageRef; use async_channel::Receiver; use async_lock::RwLock; use futures::StreamExt; use memberlist_core::{CheapClone, proto::SecretKey, tracing, transport::Transport}; -use serf_proto::MessageRef; use smol_str::{SmolStr, format_smolstr}; use crate::event::{ @@ -181,7 +181,7 @@ where event: InternalQueryEvent, ) -> Result, Error> { let kr = KeyRequestMessage { key }; - let buf = serf_proto::Encodable::encode_to_bytes(&kr)?; + let buf = crate::types::Encodable::encode_to_bytes(&kr)?; let serf = self.serf.get().unwrap(); let mut q_param = serf.default_query_param().await; @@ -248,14 +248,27 @@ where continue; } - let node_response = match serf_proto::decode_message::(&r.payload) - { - Ok(msg) => match msg 
{ - MessageRef::KeyResponse(kr) => kr, - msg => { + let node_response = + match crate::types::decode_message::(&r.payload) { + Ok(msg) => match msg { + MessageRef::KeyResponse(kr) => kr, + msg => { + resp.messages.insert( + r.from.id().cheap_clone(), + format_smolstr!("Invalid key query response type: {}", msg.ty()), + ); + resp.num_err += 1; + + if resp.num_resp == resp.num_nodes { + return resp; + } + continue; + } + }, + Err(e) => { resp.messages.insert( r.from.id().cheap_clone(), - format_smolstr!("Invalid key query response type: {}", msg.ty()), + SmolStr::new(format!("Failed to decode key query response: {:?}", e)), ); resp.num_err += 1; @@ -264,20 +277,7 @@ where } continue; } - }, - Err(e) => { - resp.messages.insert( - r.from.id().cheap_clone(), - SmolStr::new(format!("Failed to decode key query response: {:?}", e)), - ); - resp.num_err += 1; - - if resp.num_resp == resp.num_nodes { - return resp; - } - continue; - } - }; + }; if !node_response.result() { resp.messages.insert( diff --git a/serf-core/src/serf/api.rs b/serf-core/src/serf/api.rs index f61c7e7..f8b4924 100644 --- a/serf-core/src/serf/api.rs +++ b/serf-core/src/serf/api.rs @@ -4,9 +4,8 @@ use futures::{FutureExt, StreamExt}; use memberlist_core::{ CheapClone, bytes::Bytes, - proto::{Data, Meta, OneOrMore, SmallVec}, + proto::{Data, MaybeResolvedAddress, Meta, Node, OneOrMore, SmallVec}, tracing, - transport::{MaybeResolvedAddress, Node}, }; use smol_str::SmolStr; @@ -269,7 +268,7 @@ where }; // Start broadcasting the event - let len = serf_proto::Encodable::encoded_len(&msg); + let len = crate::types::Encodable::encoded_len(&msg); // Check the size after encoding to be sure again that // we're not attempting to send over the specified size limit. 
@@ -281,7 +280,7 @@ where return Err(Error::raw_user_event_too_large(len)); } - let raw = serf_proto::Encodable::encode_to_bytes(&msg)?; + let raw = crate::types::Encodable::encode_to_bytes(&msg)?; self.inner.event_clock.increment(); @@ -319,7 +318,7 @@ where /// user messages sent prior to the join will be ignored. pub async fn join( &self, - node: Node>, + node: Node>, ignore_old: bool, ) -> Result, Error> { // Do a quick state check @@ -367,7 +366,7 @@ where /// user messages sent prior to the join will be ignored. pub async fn join_many( &self, - existing: impl Iterator>>, + existing: impl Iterator>>, ignore_old: bool, ) -> Result< SmallVec>, @@ -460,7 +459,7 @@ where // other node alive. if self.has_alive_members().await { let (notify_tx, notify_rx) = async_channel::bounded(1); - let msg = serf_proto::Encodable::encode_to_bytes(&msg)?; + let msg = crate::types::Encodable::encode_to_bytes(&msg)?; self.broadcast(msg, Some(notify_tx)).await?; futures::select! { diff --git a/serf-core/src/serf/base.rs b/serf-core/src/serf/base.rs index 5c33734..e4797e5 100644 --- a/serf-core/src/serf/base.rs +++ b/serf-core/src/serf/base.rs @@ -1,5 +1,6 @@ use std::time::Duration; +use crate::types::{MessageRef, Tags, UserEventMessageRef}; use either::Either; use futures::{FutureExt, StreamExt}; use memberlist_core::{ @@ -7,12 +8,10 @@ use memberlist_core::{ agnostic_lite::AfterHandle, bytes::Bytes, delegate::EventDelegate, - proto::{Data, Meta, NodeState, OneOrMore, TinyVec}, + proto::{Data, MaybeResolvedAddress, Meta, Node, NodeState, OneOrMore, TinyVec}, tracing, - transport::{MaybeResolvedAddress, Node}, }; use rand::{Rng, SeedableRng}; -use serf_proto::{MessageRef, Tags, UserEventMessageRef}; use smol_str::SmolStr; use crate::{ @@ -387,7 +386,7 @@ where // Process update locally self.handle_node_join_intent(&msg).await; - let msg = serf_proto::Encodable::encode_to_bytes(&msg)?; + let msg = crate::types::Encodable::encode_to_bytes(&msg)?; // Start broadcasting the update if 
let Err(e) = self.broadcast(msg, None).await { tracing::warn!(err=%e, "serf: failed to broadcast join intent"); @@ -472,7 +471,7 @@ where return Ok(()); } - let msg = serf_proto::Encodable::encode_to_bytes(&msg)?; + let msg = crate::types::Encodable::encode_to_bytes(&msg)?; // Broadcast the remove let (ntx, nrx) = async_channel::bounded(1); self.broadcast(msg, Some(ntx)).await?; @@ -918,14 +917,14 @@ where }; // Encode the query - let len = serf_proto::Encodable::encoded_len(&q); + let len = crate::types::Encodable::encoded_len(&q); // Check the size if len > self.inner.opts.query_size_limit { return Err(Error::query_too_large(len)); } - let raw = serf_proto::Encodable::encode_to_bytes(&q)?; + let raw = crate::types::Encodable::encode_to_bytes(&q)?; // Register QueryResponse to track acks and responses let resp = QueryResponse::from_query(&q, self.inner.memberlist.num_online_members().await); @@ -1087,7 +1086,7 @@ where payload: Bytes::new(), }; - match serf_proto::Encodable::encode_to_bytes(&ack) { + match crate::types::Encodable::encode_to_bytes(&ack) { Ok(raw) => { let (name, payload, from, relay_factor) = match q { Either::Left(q) => ( @@ -1726,7 +1725,7 @@ where // Gather responses let resp_rx = resp.response_rx(); while let Ok(r) = resp_rx.recv().await { - let res = serf_proto::decode_message::(&r.payload); + let res = crate::types::decode_message::(&r.payload); match res { Ok(msg) => { match msg { @@ -1786,7 +1785,7 @@ where pub(crate) fn handle_rejoin( memberlist: Memberlist>, - alive_nodes: TinyVec>>, + alive_nodes: TinyVec>>, ) { ::spawn_detach(async move { for prev in alive_nodes { diff --git a/serf-core/src/serf/base/tests.rs b/serf-core/src/serf/base/tests.rs index b1ea37c..5d7c71c 100644 --- a/serf-core/src/serf/base/tests.rs +++ b/serf-core/src/serf/base/tests.rs @@ -1,16 +1,15 @@ use std::time::Duration; +use crate::types::{ + MessageRef, MessageType, Node, PushPullMessage, QueryFlag, QueryMessage, UserEvent, + UserEventMessage, +}; use 
async_channel::Receiver; use memberlist_core::{ agnostic_lite::RuntimeLite, bytes::Bytes, delegate::NodeDelegate, - transport::MaybeResolvedAddress, - proto::{OneOrMore, TinyVec}, -}; -use serf_proto::{ - MessageType, Node, PushPullMessage, QueryFlag, QueryMessage, UserEvent, - UserEventMessage, MessageRef, + proto::{MaybeResolvedAddress, OneOrMore, TinyVec}, }; use smol_str::SmolStr; @@ -332,8 +331,8 @@ pub async fn estimate_max_keys_in_list_key_response_factor( ) where T: Transport, { + use crate::types::KeyResponseMessage; use memberlist_core::proto::SecretKey; - use serf_proto::KeyResponseMessage; let size_limit = opts.query_response_size_limit() * 10; let opts = opts.with_query_response_size_limit(size_limit); @@ -355,10 +354,10 @@ pub async fn estimate_max_keys_in_list_key_response_factor( let mut found = 0; for i in (0..=resp.keys.len()).rev() { - let dst = serf_proto::Encodable::encode_to_bytes(&resp).unwrap(); + let dst = crate::types::Encodable::encode_to_bytes(&resp).unwrap(); let qresp = query.create_response(dst); - let dst = serf_proto::Encodable::encode_to_bytes(&qresp).unwrap(); + let dst = crate::types::Encodable::encode_to_bytes(&qresp).unwrap(); if query.check_response_size(dst.len()).is_err() { resp.keys.truncate(i); continue; @@ -385,8 +384,8 @@ pub async fn key_list_key_response_with_correct_size(transport_opts: T::Optio where T: Transport, { + use crate::types::KeyResponseMessage; use memberlist_core::proto::SecretKey; - use serf_proto::KeyResponseMessage; let opts = opts.with_query_response_size_limit(1024); let s = Serf::::new(transport_opts, opts).await.unwrap(); diff --git a/serf-core/src/serf/base/tests/serf.rs b/serf-core/src/serf/base/tests/serf.rs index 6f5366d..23769f9 100644 --- a/serf-core/src/serf/base/tests/serf.rs +++ b/serf-core/src/serf/base/tests/serf.rs @@ -2,7 +2,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use memberlist_core::{tests::AnyError, transport::Id}; -use serf_proto::{Member, MemberStatus, Tags}; +use 
crate::types::{Member, MemberStatus, Tags}; use crate::{event::EventProducer, types::MemberState}; diff --git a/serf-core/src/serf/base/tests/serf/delegate.rs b/serf-core/src/serf/base/tests/serf/delegate.rs index 1d3202a..cdf779a 100644 --- a/serf-core/src/serf/base/tests/serf/delegate.rs +++ b/serf-core/src/serf/base/tests/serf/delegate.rs @@ -82,7 +82,7 @@ where // Attempt a decode let pp = - serf_proto::decode_message(&buf) + crate::types::decode_message(&buf) .unwrap(); let MessageRef::PushPull(pp) = pp else { @@ -148,7 +148,7 @@ where query_ltime: 100.into(), }; - let buf = serf_proto::Encodable::encode_to_bytes(&pp).unwrap(); + let buf = crate::types::Encodable::encode_to_bytes(&pp).unwrap(); // Merge in fake state d.merge_remote_state(&buf, false).await; diff --git a/serf-core/src/serf/base/tests/serf/event.rs b/serf-core/src/serf/base/tests/serf/event.rs index d343248..d382e50 100644 --- a/serf-core/src/serf/base/tests/serf/event.rs +++ b/serf-core/src/serf/base/tests/serf/event.rs @@ -1,4 +1,4 @@ -use serf_proto::{Filter, FilterType}; +use crate::types::{Filter, FilterType}; use super::*; diff --git a/serf-core/src/serf/base/tests/serf/join.rs b/serf-core/src/serf/base/tests/serf/join.rs index 4dee47e..acbed21 100644 --- a/serf-core/src/serf/base/tests/serf/join.rs +++ b/serf-core/src/serf/base/tests/serf/join.rs @@ -53,10 +53,10 @@ pub async fn join_intent_old_message( node: Node::new("test".into(), addr), tags: Arc::new(Default::default()), status: MemberStatus::Alive, - memberlist_protocol_version: serf_proto::MemberlistProtocolVersion::V1, - memberlist_delegate_version: serf_proto::MemberlistDelegateVersion::V1, - protocol_version: serf_proto::ProtocolVersion::V1, - delegate_version: serf_proto::DelegateVersion::V1, + memberlist_protocol_version: crate::types::MemberlistProtocolVersion::V1, + memberlist_delegate_version: crate::types::MemberlistDelegateVersion::V1, + protocol_version: crate::types::ProtocolVersion::V1, + delegate_version: 
crate::types::DelegateVersion::V1, }, status_time: 12.into(), leave_time: None, @@ -104,10 +104,10 @@ pub async fn join_intent_newer( node: Node::new("test".into(), addr), tags: Arc::new(Default::default()), status: MemberStatus::Alive, - memberlist_protocol_version: serf_proto::MemberlistProtocolVersion::V1, - memberlist_delegate_version: serf_proto::MemberlistDelegateVersion::V1, - protocol_version: serf_proto::ProtocolVersion::V1, - delegate_version: serf_proto::DelegateVersion::V1, + memberlist_protocol_version: crate::types::MemberlistProtocolVersion::V1, + memberlist_delegate_version: crate::types::MemberlistDelegateVersion::V1, + protocol_version: crate::types::ProtocolVersion::V1, + delegate_version: crate::types::DelegateVersion::V1, }, status_time: 12.into(), leave_time: None, @@ -156,10 +156,10 @@ pub async fn join_intent_reset_leaving( node: Node::new("test".into(), addr), tags: Arc::new(Default::default()), status: MemberStatus::Leaving, - memberlist_protocol_version: serf_proto::MemberlistProtocolVersion::V1, - memberlist_delegate_version: serf_proto::MemberlistDelegateVersion::V1, - protocol_version: serf_proto::ProtocolVersion::V1, - delegate_version: serf_proto::DelegateVersion::V1, + memberlist_protocol_version: crate::types::MemberlistProtocolVersion::V1, + memberlist_delegate_version: crate::types::MemberlistDelegateVersion::V1, + protocol_version: crate::types::ProtocolVersion::V1, + delegate_version: crate::types::DelegateVersion::V1, }, status_time: 12.into(), leave_time: None, @@ -294,8 +294,8 @@ pub async fn join_pending_intent( addr, meta: Meta::empty(), state: memberlist_core::proto::State::Alive, - protocol_version: serf_proto::MemberlistProtocolVersion::V1, - delegate_version: serf_proto::MemberlistDelegateVersion::V1, + protocol_version: crate::types::MemberlistProtocolVersion::V1, + delegate_version: crate::types::MemberlistDelegateVersion::V1, })) .await; @@ -341,8 +341,8 @@ pub async fn join_pending_intents( addr, meta: 
Meta::empty(), state: memberlist_core::proto::State::Alive, - protocol_version: serf_proto::MemberlistProtocolVersion::V1, - delegate_version: serf_proto::MemberlistDelegateVersion::V1, + protocol_version: crate::types::MemberlistProtocolVersion::V1, + delegate_version: crate::types::MemberlistDelegateVersion::V1, })) .await; diff --git a/serf-core/src/serf/base/tests/serf/leave.rs b/serf-core/src/serf/base/tests/serf/leave.rs index 62d4208..3b772df 100644 --- a/serf-core/src/serf/base/tests/serf/leave.rs +++ b/serf-core/src/serf/base/tests/serf/leave.rs @@ -50,10 +50,10 @@ pub async fn leave_intent_old_message( node: Node::new("test".into(), addr), tags: Arc::new(Default::default()), status: MemberStatus::Alive, - memberlist_protocol_version: serf_proto::MemberlistProtocolVersion::V1, - memberlist_delegate_version: serf_proto::MemberlistDelegateVersion::V1, - protocol_version: serf_proto::ProtocolVersion::V1, - delegate_version: serf_proto::DelegateVersion::V1, + memberlist_protocol_version: crate::types::MemberlistProtocolVersion::V1, + memberlist_delegate_version: crate::types::MemberlistDelegateVersion::V1, + protocol_version: crate::types::ProtocolVersion::V1, + delegate_version: crate::types::DelegateVersion::V1, }, status_time: 12.into(), leave_time: None, @@ -101,10 +101,10 @@ pub async fn leave_intent_newer( node: Node::new("test".into(), addr), tags: Arc::new(Default::default()), status: MemberStatus::Alive, - memberlist_protocol_version: serf_proto::MemberlistProtocolVersion::V1, - memberlist_delegate_version: serf_proto::MemberlistDelegateVersion::V1, - protocol_version: serf_proto::ProtocolVersion::V1, - delegate_version: serf_proto::DelegateVersion::V1, + memberlist_protocol_version: crate::types::MemberlistProtocolVersion::V1, + memberlist_delegate_version: crate::types::MemberlistDelegateVersion::V1, + protocol_version: crate::types::ProtocolVersion::V1, + delegate_version: crate::types::DelegateVersion::V1, }, status_time: 12.into(), leave_time: 
None, diff --git a/serf-core/src/serf/delegate.rs b/serf-core/src/serf/delegate.rs index 1690fd2..7246cd7 100644 --- a/serf-core/src/serf/delegate.rs +++ b/serf-core/src/serf/delegate.rs @@ -5,9 +5,8 @@ use crate::{ error::{SerfDelegateError, SerfError}, event::QueryMessageExt, types::{ - DelegateVersion, JoinMessage, LamportTime, LeaveMessage, Member, MemberStatus, - MessageRef, MessageType, ProtocolVersion, - PushPullMessageBorrow, UserEventMessage, + DelegateVersion, JoinMessage, LamportTime, LeaveMessage, Member, MemberStatus, MessageRef, + MessageType, ProtocolVersion, PushPullMessageBorrow, UserEventMessage, }, }; @@ -16,6 +15,7 @@ use std::{ sync::{Arc, OnceLock, atomic::Ordering}, }; +use crate::types::{PushPullMessage, Tags}; use arc_swap::ArcSwap; use either::Either; use indexmap::IndexSet; @@ -30,7 +30,6 @@ use memberlist_core::{ tracing, transport::Transport, }; -use serf_proto::{PushPullMessage, Tags}; // PingVersion is an internal version for the ping message, above the normal // versioning we get from the protocol version. 
This enables small updates @@ -184,7 +183,7 @@ where let mut rebroadcast = false; let mut rebroadcast_queue = &this.inner.broadcasts; let mut relay = None; - match serf_proto::decode_message::(buf.as_ref()) { + match crate::types::decode_message::(buf.as_ref()) { Ok(msg) => { #[cfg(any(test, feature = "test"))] { @@ -405,7 +404,7 @@ where }; drop(members); - match serf_proto::Encodable::encode_to_bytes(&pp) { + match crate::types::Encodable::encode_to_bytes(&pp) { Ok(buf) => buf, Err(e) => { tracing::error!(err=%e, "serf: failed to encode local state"); @@ -421,7 +420,7 @@ where } // Check the message type - let msg = match serf_proto::decode_message::(buf) { + let msg = match crate::types::decode_message::(buf) { Ok(msg) => msg, Err(e) => { tracing::error!(err=%e, "serf: fail to decode remote state"); diff --git a/serf-core/src/serf/internal_query.rs b/serf-core/src/serf/internal_query.rs index 2861b97..9e2a23c 100644 --- a/serf-core/src/serf/internal_query.rs +++ b/serf-core/src/serf/internal_query.rs @@ -148,8 +148,8 @@ where // Encode the response match out { Some(state) => { - let resp = serf_proto::ConflictResponseMessageBorrow::from(state.member()); - match serf_proto::Encodable::encode_to_bytes(&resp) { + let resp = crate::types::ConflictResponseMessageBorrow::from(state.member()); + match crate::types::Encodable::encode_to_bytes(&resp) { Ok(raw) => { if let Err(e) = ev.respond(raw).await { tracing::error!(target="serf", err=%e, "failed to respond to conflict query"); @@ -179,7 +179,7 @@ where async fn handle_install_key(ev: impl AsRef> + Send) { let q = ev.as_ref(); let mut response = KeyResponseMessage::default(); - let req = match serf_proto::decode_message::(&q.payload) { + let req = match crate::types::decode_message::(&q.payload) { Ok(msg) => match msg { MessageRef::KeyRequest(req) => req, msg => { @@ -242,7 +242,7 @@ where let q = ev.as_ref(); let mut response = KeyResponseMessage::default(); - let req = match serf_proto::decode_message::(&q.payload) 
{ + let req = match crate::types::decode_message::(&q.payload) { Ok(msg) => match msg { MessageRef::KeyRequest(req) => req, msg => { @@ -307,7 +307,7 @@ where let q = ev.as_ref(); let mut response = KeyResponseMessage::default(); - let req = match serf_proto::decode_message::(&q.payload) { + let req = match crate::types::decode_message::(&q.payload) { Ok(msg) => match msg { MessageRef::KeyRequest(req) => req, msg => { @@ -414,7 +414,7 @@ where ) -> Result< ( Bytes, - serf_proto::QueryResponseMessage, + crate::types::QueryResponseMessage, ), Error, > { @@ -426,12 +426,12 @@ where (q.ctx.this.inner.opts.query_response_size_limit / MIN_ENCODED_KEY_LENGTH).min(actual); for i in (0..=max_list_keys).rev() { - let kraw = serf_proto::Encodable::encode_to_bytes(&*resp)?; + let kraw = crate::types::Encodable::encode_to_bytes(&*resp)?; // create response let qresp = q.create_response(kraw.clone()); - let encoded_len = serf_proto::Encodable::encoded_len(&qresp); + let encoded_len = crate::types::Encodable::encoded_len(&qresp); // Check the size limit if q.check_response_size(encoded_len).is_err() { resp.keys.drain(i..); @@ -443,7 +443,7 @@ where } // encode response - let qraw = serf_proto::Encodable::encode_to_bytes(&qresp)?; + let qraw = crate::types::Encodable::encode_to_bytes(&qresp)?; if actual > i { tracing::warn!("serf: {}", resp.message); @@ -469,7 +469,7 @@ where tracing::error!(target="serf", err=%e, "failed to respond to key query"); } } - _ => match serf_proto::Encodable::encode_to_bytes(&*resp) { + _ => match crate::types::Encodable::encode_to_bytes(&*resp) { Ok(raw) => { if let Err(e) = q.respond(raw).await { tracing::error!(target="serf", err=%e, "failed to respond to key query"); diff --git a/serf-core/src/serf/query.rs b/serf-core/src/serf/query.rs index 17340f9..12b8676 100644 --- a/serf-core/src/serf/query.rs +++ b/serf-core/src/serf/query.rs @@ -4,6 +4,7 @@ use std::{ time::{Duration, Instant}, }; +use crate::types::FilterRef; use async_channel::{Receiver, 
Sender}; use async_lock::RwLock; use either::Either; @@ -15,7 +16,6 @@ use memberlist_core::{ tracing, transport::{Node, Transport}, }; -use serf_proto::FilterRef; use crate::{ delegate::Delegate, @@ -557,14 +557,14 @@ where } // Prep the relay message, which is a wrapped version of the original. - let encoded_len = serf_proto::Encodable::encoded_len_with_relay(&resp, &node); + let encoded_len = crate::types::Encodable::encoded_len_with_relay(&resp, &node); if encoded_len > self.inner.opts.query_response_size_limit { return Err(Error::relayed_response_too_large( self.inner.opts.query_response_size_limit, )); } - let raw = serf_proto::Encodable::encode_relay_to_bytes(&resp, &node)?; + let raw = crate::types::Encodable::encode_relay_to_bytes(&resp, &node)?; // Relay to a random set of peers. let relay_members = random_members(relay_factor as usize, members); diff --git a/serf-core/src/snapshot.rs b/serf-core/src/snapshot.rs index ee512cd..7edb9a8 100644 --- a/serf-core/src/snapshot.rs +++ b/serf-core/src/snapshot.rs @@ -11,6 +11,7 @@ use std::{ #[cfg(unix)] use std::os::unix::prelude::OpenOptionsExt; +use crate::types::UserEventMessage; use async_channel::{Receiver, Sender}; use byteorder::{LittleEndian, ReadBytesExt}; use futures::FutureExt; @@ -18,12 +19,11 @@ use memberlist_core::{ CheapClone, agnostic_lite::{AsyncSpawner, RuntimeLite}, bytes::{BufMut, BytesMut}, - proto::{Data, TinyVec}, + proto::{Data, MaybeResolvedAddress, TinyVec}, tracing, - transport::{Id, MaybeResolvedAddress, Node, Transport}, + transport::{Id, Node, Transport}, }; use rand::seq::SliceRandom; -use serf_proto::UserEventMessage; use crate::{ delegate::Delegate, @@ -446,7 +446,7 @@ where ) -> Result< ( Sender>, - TinyVec>>, + TinyVec>>, SnapshotHandle, ), SnapshotError, diff --git a/serf-core/src/types.rs b/serf-core/src/types.rs index 2e5d0b2..0a44972 100644 --- a/serf-core/src/types.rs +++ b/serf-core/src/types.rs @@ -1,9 +1,63 @@ -pub use serf_proto::*; +use std::time::Duration; + +pub use 
memberlist_core::proto::{ + DelegateVersion as MemberlistDelegateVersion, Domain, HostAddr, Node, NodeId, ParseDomainError, + ParseHostAddrError, ParseNodeIdError, ProtocolVersion as MemberlistProtocolVersion, +}; + +#[cfg(feature = "arbitrary")] +mod arbitrary_impl; + +mod clock; +pub use clock::*; + +mod conflict; +pub(crate) use conflict::*; + +mod filter; +pub(crate) use filter::*; + +mod leave; +pub(crate) use leave::*; mod member; -pub(crate) use member::*; +pub use member::*; -use std::time::Duration; +mod message; +pub(crate) use message::*; + +mod join; +pub(crate) use join::*; + +mod tags; +pub use tags::*; + +mod push_pull; +pub(crate) use push_pull::*; + +mod user_event; +pub(crate) use user_event::*; + +mod query; +pub(crate) use query::*; + +mod version; +pub use version::*; + +#[cfg(feature = "encryption")] +mod key; +#[cfg(feature = "encryption")] +#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] +pub use key::*; + +#[cfg(debug_assertions)] +#[inline] +fn debug_assert_write_eq(actual: usize, expected: usize) { + debug_assert_eq!( + actual, expected, + "expect writting {expected} bytes, but actual write {actual} bytes" + ); +} #[cfg(windows)] pub(crate) type Epoch = system_epoch::SystemTimeEpoch; diff --git a/serf-proto/src/arbitrary_impl.rs b/serf-core/src/types/arbitrary_impl.rs similarity index 98% rename from serf-proto/src/arbitrary_impl.rs rename to serf-core/src/types/arbitrary_impl.rs index d50f558..65a4518 100644 --- a/serf-proto/src/arbitrary_impl.rs +++ b/serf-core/src/types/arbitrary_impl.rs @@ -3,12 +3,10 @@ use std::{ hash::Hash, }; -use crate::TagFilter; - -use super::Filter; +use super::{Filter, TagFilter}; use arbitrary::{Arbitrary, Unstructured}; use indexmap::{IndexMap, IndexSet}; -use memberlist_proto::TinyVec; +use memberlist_core::proto::TinyVec; pub(super) fn into<'a, F, T>(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result where diff --git a/serf-proto/src/clock.rs b/serf-core/src/types/clock.rs similarity index 
93% rename from serf-proto/src/clock.rs rename to serf-core/src/types/clock.rs index 84b8ffe..ba26c9b 100644 --- a/serf-proto/src/clock.rs +++ b/serf-core/src/types/clock.rs @@ -3,7 +3,7 @@ use std::sync::{ atomic::{AtomicU64, Ordering}, }; -use memberlist_proto::{Data, DataRef}; +use memberlist_core::proto::{Data, DataRef, DecodeError, EncodeError}; /// A lamport time is a simple u64 that represents a point in time. #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] @@ -96,7 +96,7 @@ impl core::ops::Rem for LamportTime { impl Data for LamportTime { type Ref<'a> = Self; - fn from_ref(val: Self::Ref<'_>) -> Result + fn from_ref(val: Self::Ref<'_>) -> Result where Self: Sized, { @@ -107,13 +107,13 @@ impl Data for LamportTime { ::encoded_len(&self.0) } - fn encode(&self, buf: &mut [u8]) -> Result { + fn encode(&self, buf: &mut [u8]) -> Result { ::encode(&self.0, buf) } } impl<'a> DataRef<'a, LamportTime> for LamportTime { - fn decode(src: &'a [u8]) -> Result<(usize, LamportTime), memberlist_proto::DecodeError> { + fn decode(src: &'a [u8]) -> Result<(usize, LamportTime), DecodeError> { >::decode(src).map(|(n, v)| (n, v.into())) } } diff --git a/serf-proto/src/conflict.rs b/serf-core/src/types/conflict.rs similarity index 99% rename from serf-proto/src/conflict.rs rename to serf-core/src/types/conflict.rs index 087959a..b89bec4 100644 --- a/serf-proto/src/conflict.rs +++ b/serf-core/src/types/conflict.rs @@ -1,4 +1,4 @@ -use memberlist_proto::{ +use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, WireType, utils::{merge, skip, split}, }; diff --git a/serf-proto/src/filter.rs b/serf-core/src/types/filter.rs similarity index 99% rename from serf-proto/src/filter.rs rename to serf-core/src/types/filter.rs index 8b88a52..52f38dd 100644 --- a/serf-proto/src/filter.rs +++ b/serf-core/src/types/filter.rs @@ -1,4 +1,4 @@ -use memberlist_proto::{ +use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, RepeatedDecoder, 
TinyVec, WireType, utils::{merge, skip, split}, }; diff --git a/serf-proto/src/filter/tag_filter.rs b/serf-core/src/types/filter/tag_filter.rs similarity index 98% rename from serf-proto/src/filter/tag_filter.rs rename to serf-core/src/types/filter/tag_filter.rs index fb241f9..e566800 100644 --- a/serf-proto/src/filter/tag_filter.rs +++ b/serf-core/src/types/filter/tag_filter.rs @@ -1,4 +1,4 @@ -use memberlist_proto::{ +use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, WireType, utils::{merge, skip, split}, }; @@ -21,7 +21,7 @@ pub struct TagFilterRef<'a> { } impl<'a> DataRef<'a, TagFilter> for TagFilterRef<'a> { - fn decode(src: &'a [u8]) -> Result<(usize, Self), memberlist_proto::DecodeError> + fn decode(src: &'a [u8]) -> Result<(usize, Self), DecodeError> where Self: Sized, { diff --git a/serf-proto/src/join.rs b/serf-core/src/types/join.rs similarity index 97% rename from serf-proto/src/join.rs rename to serf-core/src/types/join.rs index 906ded5..4a0aff1 100644 --- a/serf-proto/src/join.rs +++ b/serf-core/src/types/join.rs @@ -1,4 +1,4 @@ -use memberlist_proto::{ +use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, WireType, utils::{merge, skip, split}, }; @@ -126,7 +126,7 @@ where { type Ref<'a> = JoinMessage>; - fn from_ref(val: Self::Ref<'_>) -> Result + fn from_ref(val: Self::Ref<'_>) -> Result where Self: Sized, { diff --git a/serf-proto/src/key.rs b/serf-core/src/types/key.rs similarity index 99% rename from serf-proto/src/key.rs rename to serf-core/src/types/key.rs index 023faff..be1705a 100644 --- a/serf-proto/src/key.rs +++ b/serf-core/src/types/key.rs @@ -1,5 +1,5 @@ use indexmap::IndexMap; -use memberlist_proto::{ +use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, RepeatedDecoder, SecretKey, SecretKeys, WireType, utils::{merge, skip, split}, }; diff --git a/serf-proto/src/leave.rs b/serf-core/src/types/leave.rs similarity index 99% rename from serf-proto/src/leave.rs rename to 
serf-core/src/types/leave.rs index 8d0f0b1..42e3d9e 100644 --- a/serf-proto/src/leave.rs +++ b/serf-core/src/types/leave.rs @@ -1,4 +1,4 @@ -use memberlist_proto::{ +use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, WireType, utils::{merge, skip, split}, }; diff --git a/serf-core/src/types/member.rs b/serf-core/src/types/member.rs index 6400a81..a8c174b 100644 --- a/serf-core/src/types/member.rs +++ b/serf-core/src/types/member.rs @@ -1,9 +1,16 @@ -use memberlist_core::proto::OneOrMore; -use serf_proto::{Member, MessageType}; +use std::sync::Arc; -use std::collections::HashMap; +use memberlist_core::proto::{ + CheapClone, Data, DataRef, DecodeError, EncodeError, OneOrMore, WireType, + utils::{merge, skip, split}, +}; + +use super::{ + DelegateVersion, Epoch, LamportTime, MemberlistDelegateVersion, MemberlistProtocolVersion, + MessageType, Node, ProtocolVersion, Tags, TagsRef, +}; -use super::{Epoch, LamportTime}; +use std::collections::HashMap; /// Used to track members that are no longer active due to /// leaving, failing, partitioning, etc. It tracks the member along with @@ -43,3 +50,452 @@ impl Default for Members { } } } + +const MEMBER_STATUS_NONE: u8 = 0; +const MEMBER_STATUS_ALIVE: u8 = 1; +const MEMBER_STATUS_LEAVING: u8 = 2; +const MEMBER_STATUS_LEFT: u8 = 3; +const MEMBER_STATUS_FAILED: u8 = 4; + +/// The member status. 
+#[derive( + Debug, Default, Copy, Clone, Eq, PartialEq, Hash, derive_more::IsVariant, derive_more::Display, +)] +#[repr(u8)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[non_exhaustive] +pub enum MemberStatus { + /// None status + #[display("none")] + #[default] + None, + /// Alive status + #[display("alive")] + Alive, + /// Leaving status + #[display("leaving")] + Leaving, + /// Left status + #[display("left")] + Left, + /// Failed status + #[display("failed")] + Failed, + /// Unknown state (used for forwards and backwards compatibility) + #[display("unknown({_0})")] + Unknown(u8), +} + +impl From for MemberStatus { + fn from(value: u8) -> Self { + match value { + MEMBER_STATUS_NONE => Self::None, + MEMBER_STATUS_ALIVE => Self::Alive, + MEMBER_STATUS_LEAVING => Self::Leaving, + MEMBER_STATUS_LEFT => Self::Left, + MEMBER_STATUS_FAILED => Self::Failed, + val => Self::Unknown(val), + } + } +} + +impl From for u8 { + fn from(val: MemberStatus) -> Self { + match val { + MemberStatus::None => MEMBER_STATUS_NONE, + MemberStatus::Alive => MEMBER_STATUS_ALIVE, + MemberStatus::Leaving => MEMBER_STATUS_LEAVING, + MemberStatus::Left => MEMBER_STATUS_LEFT, + MemberStatus::Failed => MEMBER_STATUS_FAILED, + MemberStatus::Unknown(val) => val, + } + } +} + +impl MemberStatus { + /// Get the string representation of the member status + #[inline] + pub fn as_str(&self) -> std::borrow::Cow<'static, str> { + std::borrow::Cow::Borrowed(match self { + Self::None => "none", + Self::Alive => "alive", + Self::Leaving => "leaving", + Self::Left => "left", + Self::Failed => "failed", + Self::Unknown(val) => return format!("unknown({})", val).into(), + }) + } +} + +/// A single member of the Serf cluster. 
+#[viewit::viewit(setters(prefix = "with"))] +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct Member { + /// The node + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the node")), + setter(attrs(doc = "Sets the node (Builder pattern)")) + )] + node: Node, + /// The tags + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the tags")), + setter(attrs(doc = "Sets the tags (Builder pattern)")) + )] + tags: Arc, + /// The status + #[viewit( + getter(const, style = "ref", attrs(doc = "Returns the status")), + setter(attrs(doc = "Sets the status (Builder pattern)")) + )] + status: MemberStatus, + /// The memberlist protocol version + #[viewit( + getter(const, attrs(doc = "Returns the memberlist protocol version")), + setter( + const, + attrs(doc = "Sets the memberlist protocol version (Builder pattern)") + ) + )] + memberlist_protocol_version: MemberlistProtocolVersion, + /// The memberlist delegate version + #[viewit( + getter(const, attrs(doc = "Returns the memberlist delegate version")), + setter( + const, + attrs(doc = "Sets the memberlist delegate version (Builder pattern)") + ) + )] + memberlist_delegate_version: MemberlistDelegateVersion, + + /// The serf protocol version + #[viewit( + getter(const, attrs(doc = "Returns the serf protocol version")), + setter(const, attrs(doc = "Sets the serf protocol version (Builder pattern)")) + )] + protocol_version: ProtocolVersion, + /// The serf delegate version + #[viewit( + getter(const, attrs(doc = "Returns the serf delegate version")), + setter(const, attrs(doc = "Sets the serf delegate version (Builder pattern)")) + )] + delegate_version: DelegateVersion, +} + +impl Member { + /// Create a new member with the given node, tags, and status. + /// Other fields are set to their default values. 
+ #[inline] + pub fn new(node: Node, tags: Tags, status: MemberStatus) -> Self { + Self { + node, + tags: Arc::new(tags), + status, + memberlist_protocol_version: MemberlistProtocolVersion::V1, + memberlist_delegate_version: MemberlistDelegateVersion::V1, + protocol_version: ProtocolVersion::V1, + delegate_version: DelegateVersion::V1, + } + } +} + +impl Clone for Member { + fn clone(&self) -> Self { + Self { + node: self.node.clone(), + tags: self.tags.clone(), + status: self.status, + memberlist_protocol_version: self.memberlist_protocol_version, + memberlist_delegate_version: self.memberlist_delegate_version, + protocol_version: self.protocol_version, + delegate_version: self.delegate_version, + } + } +} + +impl CheapClone for Member { + fn cheap_clone(&self) -> Self { + Self { + node: self.node.cheap_clone(), + tags: self.tags.cheap_clone(), + status: self.status, + memberlist_protocol_version: self.memberlist_protocol_version, + memberlist_delegate_version: self.memberlist_delegate_version, + protocol_version: self.protocol_version, + delegate_version: self.delegate_version, + } + } +} + +const NODE_TAG: u8 = 1; +const TAGS_TAG: u8 = 2; +const STATUS_TAG: u8 = 3; +const MEMBERLIST_PROTOCOL_VERSION_TAG: u8 = 4; +const MEMBERLIST_DELEGATE_VERSION_TAG: u8 = 5; +const PROTOCOL_VERSION_TAG: u8 = 6; +const DELEGATE_VERSION_TAG: u8 = 7; + +const NODE_BYTE: u8 = merge(WireType::LengthDelimited, NODE_TAG); +const TAGS_BYTE: u8 = merge(WireType::LengthDelimited, TAGS_TAG); +const STATUS_BYTE: u8 = merge(WireType::Byte, STATUS_TAG); +const MEMBERLIST_PROTOCOL_VERSION_BYTE: u8 = merge(WireType::Byte, MEMBERLIST_PROTOCOL_VERSION_TAG); +const MEMBERLIST_DELEGATE_VERSION_BYTE: u8 = merge(WireType::Byte, MEMBERLIST_DELEGATE_VERSION_TAG); +const PROTOCOL_VERSION_BYTE: u8 = merge(WireType::Byte, PROTOCOL_VERSION_TAG); +const DELEGATE_VERSION_BYTE: u8 = merge(WireType::Byte, DELEGATE_VERSION_TAG); + +/// A reference type to [`Member`] +#[viewit::viewit(vis_all = "", 
getters(vis_all = "pub"), setters(skip))] +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct MemberRef<'a, I, A> { + /// The node + #[viewit(getter(const, style = "ref", attrs(doc = "Returns the node")))] + node: Node, + /// The tags + #[viewit(getter(const, style = "ref", attrs(doc = "Returns the tags")))] + tags: TagsRef<'a>, + /// The status + #[viewit(getter(const, style = "ref", attrs(doc = "Returns the status")))] + status: MemberStatus, + /// The memberlist protocol version + #[viewit(getter(const, attrs(doc = "Returns the memberlist protocol version")))] + memberlist_protocol_version: MemberlistProtocolVersion, + /// The memberlist delegate version + #[viewit(getter(const, attrs(doc = "Returns the memberlist delegate version")))] + memberlist_delegate_version: MemberlistDelegateVersion, + /// The serf protocol version + #[viewit(getter(const, attrs(doc = "Returns the serf protocol version")))] + protocol_version: ProtocolVersion, + /// The serf delegate version + #[viewit(getter(const, attrs(doc = "Returns the serf delegate version")))] + delegate_version: DelegateVersion, +} + +impl<'a, I, A> DataRef<'a, Member> for MemberRef<'a, I::Ref<'a>, A::Ref<'a>> +where + I: Data, + A: Data, +{ + fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> + where + Self: Sized, + { + let mut offset = 0; + let buf_len = buf.len(); + + let mut node = None; + let mut tags = None; + let mut status = None; + let mut memberlist_protocol_version = None; + let mut memberlist_delegate_version = None; + let mut protocol_version = None; + let mut delegate_version = None; + + while offset < buf_len { + match buf[offset] { + NODE_BYTE => { + if node.is_some() { + return Err(DecodeError::duplicate_field("Member", "node", NODE_TAG)); + } + offset += 1; + let (size, val) = + , A::Ref<'_>> as DataRef<'_, Node>>::decode_length_delimited( + &buf[offset..], + )?; + node = Some(val); + offset += size; + } + TAGS_BYTE => { + if tags.is_some() { + return 
Err(DecodeError::duplicate_field("Member", "tags", TAGS_TAG)); + } + offset += 1; + let (size, val) = + as DataRef<'_, Tags>>::decode_length_delimited(&buf[offset..])?; + tags = Some(val); + offset += size; + } + STATUS_BYTE => { + if status.is_some() { + return Err(DecodeError::duplicate_field("Member", "status", STATUS_TAG)); + } + offset += 1; + status = Some(buf[offset].into()); + offset += 1; + } + MEMBERLIST_PROTOCOL_VERSION_BYTE => { + if memberlist_protocol_version.is_some() { + return Err(DecodeError::duplicate_field( + "Member", + "memberlist_protocol_version", + MEMBERLIST_PROTOCOL_VERSION_TAG, + )); + } + offset += 1; + memberlist_protocol_version = Some(buf[offset].into()); + offset += 1; + } + MEMBERLIST_DELEGATE_VERSION_BYTE => { + if memberlist_delegate_version.is_some() { + return Err(DecodeError::duplicate_field( + "Member", + "memberlist_delegate_version", + MEMBERLIST_DELEGATE_VERSION_TAG, + )); + } + offset += 1; + memberlist_delegate_version = Some(buf[offset].into()); + offset += 1; + } + PROTOCOL_VERSION_BYTE => { + if protocol_version.is_some() { + return Err(DecodeError::duplicate_field( + "Member", + "protocol_version", + PROTOCOL_VERSION_TAG, + )); + } + offset += 1; + protocol_version = Some(buf[offset].into()); + offset += 1; + } + DELEGATE_VERSION_BYTE => { + if delegate_version.is_some() { + return Err(DecodeError::duplicate_field( + "Member", + "delegate_version", + DELEGATE_VERSION_TAG, + )); + } + offset += 1; + delegate_version = Some(buf[offset].into()); + offset += 1; + } + other => { + offset += 1; + + let (wire_type, _) = split(other); + let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + offset += skip(wire_type, &buf[offset..])?; + } + } + } + + Ok(( + offset, + Self { + node: node.ok_or_else(|| DecodeError::missing_field("Member", "node"))?, + tags: tags.ok_or_else(|| DecodeError::missing_field("Member", "tags"))?, + status: status.ok_or_else(|| DecodeError::missing_field("Member", 
"status"))?, + memberlist_protocol_version: memberlist_protocol_version + .ok_or_else(|| DecodeError::missing_field("Member", "memberlist_protocol_version"))?, + memberlist_delegate_version: memberlist_delegate_version + .ok_or_else(|| DecodeError::missing_field("Member", "memberlist_delegate_version"))?, + protocol_version: protocol_version + .ok_or_else(|| DecodeError::missing_field("Member", "protocol_version"))?, + delegate_version: delegate_version + .ok_or_else(|| DecodeError::missing_field("Member", "delegate_version"))?, + }, + )) + } +} + +impl Data for Member +where + I: Data, + A: Data, +{ + type Ref<'a> = MemberRef<'a, I::Ref<'a>, A::Ref<'a>>; + + fn from_ref(val: Self::Ref<'_>) -> Result + where + Self: Sized, + { + Ok(Self { + node: Node::from_ref(val.node)?, + tags: Tags::from_ref(val.tags)?.into(), + status: val.status, + memberlist_protocol_version: val.memberlist_protocol_version, + memberlist_delegate_version: val.memberlist_delegate_version, + protocol_version: val.protocol_version, + delegate_version: val.delegate_version, + }) + } + + fn encoded_len(&self) -> usize { + let mut len = 0; + len += 1 + self.node.encoded_len_with_length_delimited(); + len += 1 + self.tags.encoded_len_with_length_delimited(); + len += 1 + 1; // status + len += 1 + 1; // memberlist_protocol_version + len += 1 + 1; // memberlist_delegate_version + len += 1 + 1; // protocol_version + len += 1 + 1; // delegate_version + len + } + + fn encode(&self, buf: &mut [u8]) -> Result { + macro_rules! 
bail { + ($this:ident($offset:expr, $len:ident)) => { + if $offset >= $len { + return Err(EncodeError::insufficient_buffer(self.encoded_len(), $len)); + } + }; + } + + let buf_len = buf.len(); + let mut offset = 0; + bail!(self(offset, buf_len)); + + buf[offset] = NODE_BYTE; + offset += 1; + offset += self.node.encode_length_delimited(&mut buf[offset..])?; + + bail!(self(offset, buf_len)); + buf[offset] = TAGS_BYTE; + offset += 1; + offset += self.tags.encode_length_delimited(&mut buf[offset..])?; + + bail!(self(offset, buf_len)); + buf[offset] = STATUS_BYTE; + offset += 1; + + bail!(self(offset, buf_len)); + buf[offset] = self.status.into(); + offset += 1; + + bail!(self(offset, buf_len)); + buf[offset] = MEMBERLIST_PROTOCOL_VERSION_BYTE; + offset += 1; + + bail!(self(offset, buf_len)); + buf[offset] = self.memberlist_protocol_version.into(); + offset += 1; + + bail!(self(offset, buf_len)); + buf[offset] = MEMBERLIST_DELEGATE_VERSION_BYTE; + offset += 1; + + bail!(self(offset, buf_len)); + buf[offset] = self.memberlist_delegate_version.into(); + offset += 1; + + bail!(self(offset, buf_len)); + buf[offset] = PROTOCOL_VERSION_BYTE; + offset += 1; + + bail!(self(offset, buf_len)); + buf[offset] = self.protocol_version.into(); + offset += 1; + + bail!(self(offset, buf_len)); + buf[offset] = DELEGATE_VERSION_BYTE; + offset += 1; + + #[cfg(debug_assertions)] + super::debug_assert_write_eq(offset, self.encoded_len()); + + Ok(offset) + } +} diff --git a/serf-proto/src/message.rs b/serf-core/src/types/message.rs similarity index 98% rename from serf-proto/src/message.rs rename to serf-core/src/types/message.rs index bb2d87f..9d86d59 100644 --- a/serf-proto/src/message.rs +++ b/serf-core/src/types/message.rs @@ -1,17 +1,14 @@ -use memberlist_proto::{ +use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, Node, WireType, bytes::Bytes, utils::{merge, skip, split}, }; -use crate::{ - ConflictResponseMessageRef, PushPullMessage, PushPullMessageRef, 
QueryMessageRef, - QueryResponseMessageRef, UserEventMessageRef, -}; - use super::{ - ConflictResponseMessage, ConflictResponseMessageBorrow, JoinMessage, LeaveMessage, - PushPullMessageBorrow, QueryMessage, QueryResponseMessage, UserEventMessage, + ConflictResponseMessage, ConflictResponseMessageBorrow, ConflictResponseMessageRef, JoinMessage, + LeaveMessage, PushPullMessage, PushPullMessageBorrow, PushPullMessageRef, QueryMessage, + QueryMessageRef, QueryResponseMessage, QueryResponseMessageRef, UserEventMessage, + UserEventMessageRef, }; #[cfg(feature = "encryption")] diff --git a/serf-proto/src/push_pull.rs b/serf-core/src/types/push_pull.rs similarity index 97% rename from serf-proto/src/push_pull.rs rename to serf-core/src/types/push_pull.rs index d748bfd..4383a86 100644 --- a/serf-proto/src/push_pull.rs +++ b/serf-core/src/types/push_pull.rs @@ -1,5 +1,5 @@ use indexmap::{IndexMap, IndexSet}; -use memberlist_proto::{ +use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, RepeatedDecoder, TinyVec, TupleEncoder, WireType, utils::{merge, skip, split}, }; @@ -39,14 +39,14 @@ pub struct PushPullMessage { ), setter(attrs(doc = "Sets the maps the node to its status time (Builder pattern)")) )] - #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::arbitrary_indexmap))] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::types::arbitrary_impl::arbitrary_indexmap))] status_ltimes: IndexMap, /// List of left nodes #[viewit( getter(const, style = "ref", attrs(doc = "Returns the list of left nodes")), setter(attrs(doc = "Sets the list of left nodes (Builder pattern)")) )] - #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::arbitrary_indexset))] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::types::arbitrary_impl::arbitrary_indexset))] left_members: IndexSet, /// Lamport time for event clock #[viewit( @@ -66,7 +66,7 @@ pub struct PushPullMessage { getter(const, style = "ref", attrs(doc = 
"Returns the recent events")), setter(attrs(doc = "Sets the recent events (Builder pattern)")) )] - #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, TinyVec>))] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::types::arbitrary_impl::into::, TinyVec>))] events: TinyVec, /// Lamport time for query clock #[viewit( diff --git a/serf-proto/src/query.rs b/serf-core/src/types/query.rs similarity index 98% rename from serf-proto/src/query.rs rename to serf-core/src/types/query.rs index 8983943..f4935c0 100644 --- a/serf-proto/src/query.rs +++ b/serf-core/src/types/query.rs @@ -2,7 +2,7 @@ use smol_str::SmolStr; use std::time::Duration; -use memberlist_proto::{ +use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, Node, RepeatedDecoder, TinyVec, WireType, bytes::Bytes, utils::{merge, skip, split}, @@ -77,7 +77,7 @@ pub struct QueryMessage { getter(const, attrs(doc = "Returns the potential query filters")), setter(attrs(doc = "Sets the potential query filters (Builder pattern)")) )] - #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::>, TinyVec>>))] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::types::arbitrary_impl::into::>, TinyVec>>))] filters: TinyVec>, /// Used to provide various flags #[viewit( @@ -116,7 +116,7 @@ pub struct QueryMessage { getter(const, style = "ref", attrs(doc = "Returns the payload")), setter(attrs(doc = "Sets the payload (Builder pattern)")) )] - #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, Bytes>))] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::types::arbitrary_impl::into::, Bytes>))] payload: Bytes, } diff --git a/serf-proto/src/query/response.rs b/serf-core/src/types/query/response.rs similarity index 98% rename from serf-proto/src/query/response.rs rename to serf-core/src/types/query/response.rs index 48bfbbf..999ef30 100644 --- a/serf-proto/src/query/response.rs +++ 
b/serf-core/src/types/query/response.rs @@ -1,12 +1,10 @@ -use memberlist_proto::{ +use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, Node, WireType, bytes::Bytes, utils::{merge, skip, split}, }; -use crate::LamportTime; - -use super::QueryFlag; +use super::{LamportTime, QueryFlag}; const LTIME_TAG: u8 = 1; const ID_TAG: u8 = 2; @@ -58,7 +56,7 @@ pub struct QueryResponseMessage { getter(const, style = "ref", attrs(doc = "Returns the payload")), setter(attrs(doc = "Sets the payload (Builder pattern)")) )] - #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, Bytes>))] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::types::arbitrary_impl::into::, Bytes>))] payload: Bytes, } diff --git a/serf-proto/src/tags.rs b/serf-core/src/types/tags.rs similarity index 94% rename from serf-proto/src/tags.rs rename to serf-core/src/types/tags.rs index 7d52fb9..1725959 100644 --- a/serf-proto/src/tags.rs +++ b/serf-core/src/types/tags.rs @@ -1,5 +1,5 @@ use indexmap::IndexMap; -use memberlist_proto::{ +use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, RepeatedDecoder, TupleEncoder, WireType, utils::{merge, skip, split}, }; @@ -23,8 +23,8 @@ const TAGS_BYTE: u8 = merge(WireType::LengthDelimited, TAGS_TAG); #[cfg_attr(feature = "serde", serde(transparent))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub struct Tags( - #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::arbitrary_indexmap))] - IndexMap, + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::types::arbitrary_impl::arbitrary_indexmap))] + IndexMap, ); impl IntoIterator for Tags { @@ -73,7 +73,7 @@ pub struct TagsRef<'a> { } impl<'a> DataRef<'a, Tags> for TagsRef<'a> { - fn decode(src: &'a [u8]) -> Result<(usize, Self), memberlist_proto::DecodeError> + fn decode(src: &'a [u8]) -> Result<(usize, Self), DecodeError> where Self: Sized, { diff --git a/serf-proto/src/user_event.rs 
b/serf-core/src/types/user_event.rs similarity index 96% rename from serf-proto/src/user_event.rs rename to serf-core/src/types/user_event.rs index a929286..ef63bd8 100644 --- a/serf-proto/src/user_event.rs +++ b/serf-core/src/types/user_event.rs @@ -1,4 +1,4 @@ -use memberlist_proto::{ +use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, WireType, bytes::Bytes, utils::{merge, skip, split}, @@ -34,7 +34,7 @@ pub struct UserEvent { getter(const, attrs(doc = "Returns the payload of the event")), setter(attrs(doc = "Sets the payload of the event (Builder pattern)")) )] - #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, Bytes>))] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::types::arbitrary_impl::into::, Bytes>))] payload: Bytes, } diff --git a/serf-proto/src/user_event/message.rs b/serf-core/src/types/user_event/message.rs similarity index 97% rename from serf-proto/src/user_event/message.rs rename to serf-core/src/types/user_event/message.rs index 182cec5..ff9724b 100644 --- a/serf-proto/src/user_event/message.rs +++ b/serf-core/src/types/user_event/message.rs @@ -1,11 +1,11 @@ -use memberlist_proto::{ +use memberlist_core::proto::{ CheapClone, Data, DataRef, DecodeError, EncodeError, WireType, bytes::Bytes, utils::{merge, skip, split}, }; use smol_str::SmolStr; -use crate::LamportTime; +use super::super::LamportTime; /// Used for user-generated events #[viewit::viewit(setters(prefix = "with"))] @@ -37,7 +37,7 @@ pub struct UserEventMessage { getter(const, attrs(doc = "Returns the payload of the event")), setter(attrs(doc = "Sets the payload of the event (Builder pattern)")) )] - #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, Bytes>))] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::types::arbitrary_impl::into::, Bytes>))] payload: Bytes, /// "Can Coalesce". 
#[viewit( diff --git a/serf-proto/src/user_event/user_events.rs b/serf-core/src/types/user_event/user_events.rs similarity index 95% rename from serf-proto/src/user_event/user_events.rs rename to serf-core/src/types/user_event/user_events.rs index 17f02d1..3fe019a 100644 --- a/serf-proto/src/user_event/user_events.rs +++ b/serf-core/src/types/user_event/user_events.rs @@ -1,11 +1,9 @@ -use memberlist_proto::{ +use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, OneOrMore, RepeatedDecoder, WireType, utils::{merge, skip, split}, }; -use crate::LamportTime; - -use super::UserEvent; +use super::{super::LamportTime, UserEvent}; const LTIME_TAG: u8 = 1; const EVENTS_TAG: u8 = 2; @@ -34,7 +32,7 @@ pub struct UserEvents { getter(const, style = "ref", attrs(doc = "Returns the user events")), setter(attrs(doc = "Sets the user events (Builder pattern)")) )] - #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::arbitrary_impl::into::, OneOrMore>))] + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::types::arbitrary_impl::into::, OneOrMore>))] events: OneOrMore, } diff --git a/serf-proto/src/version.rs b/serf-core/src/types/version.rs similarity index 97% rename from serf-proto/src/version.rs rename to serf-core/src/types/version.rs index bb110a6..710d49f 100644 --- a/serf-proto/src/version.rs +++ b/serf-core/src/types/version.rs @@ -1,4 +1,4 @@ -use memberlist_proto::{Data, DataRef, DecodeError, EncodeError, WireType}; +use memberlist_core::proto::{Data, DataRef, DecodeError, EncodeError, WireType}; /// Delegate version #[derive( diff --git a/serf-proto/Cargo.toml b/serf-proto/Cargo.toml deleted file mode 100644 index cec8ff9..0000000 --- a/serf-proto/Cargo.toml +++ /dev/null @@ -1,43 +0,0 @@ -[package] -name = "serf-proto" -version = "0.1.0" -rust-version.workspace = true -edition.workspace = true -repository.workspace = true -homepage.workspace = true -license.workspace = true -description = "Types for the `serf` crate" - -[features] 
-encryption = ["memberlist-proto/encryption", "futures"] -serde = ["dep:serde", "indexmap/serde", "memberlist-proto/serde", "smol_str/serde", "bitflags/serde"] -metrics = ["memberlist-proto/metrics"] - -arbitrary = ["dep:arbitrary", "memberlist-proto/arbitrary", "smol_str/arbitrary"] -quickcheck = ["dep:quickcheck", "memberlist-proto/quickcheck"] - -[dependencies] -bitflags = "2" -byteorder.workspace = true -bytemuck = { version = "1", features = ["derive"] } -derive_more = { workspace = true, features = ["is_variant", "display", "unwrap", "try_unwrap"] } -futures = { workspace = true, optional = true, features = ["alloc"] } -indexmap.workspace = true -memberlist-proto.workspace = true -regex.workspace = true -smol_str.workspace = true -thiserror.workspace = true -viewit.workspace = true - -serde = { workspace = true, optional = true } - -arbitrary = { version = "1", optional = true, default-features = false, features = ["derive"] } -quickcheck = { version = "1", optional = true, default-features = false } - -[dev-dependencies] -rand.workspace = true -futures = { workspace = true, features = ["executor"] } - -[package.metadata.docs.rs] -all-features = true -rustdoc-args = ["--cfg", "docsrs"] \ No newline at end of file diff --git a/serf-proto/src/lib.rs b/serf-proto/src/lib.rs deleted file mode 100644 index 8c42dc2..0000000 --- a/serf-proto/src/lib.rs +++ /dev/null @@ -1,75 +0,0 @@ -//! Types used by the [`serf`](https://crates.io/crates/serf) crate. 
-#![doc(html_logo_url = "https://raw.githubusercontent.com/al8n/memberlist/main/art/logo_72x72.png")] -#![forbid(unsafe_code)] -#![deny(warnings, missing_docs)] -#![allow(clippy::type_complexity, unexpected_cfgs)] -#![cfg_attr(docsrs, feature(doc_cfg))] -#![cfg_attr(docsrs, allow(unused_attributes))] - -pub use memberlist_proto::{ - DelegateVersion as MemberlistDelegateVersion, Domain, HostAddr, Node, NodeId, ParseDomainError, - ParseHostAddrError, ProtocolVersion as MemberlistProtocolVersion, -}; - -#[cfg(feature = "arbitrary")] -mod arbitrary_impl; - -mod clock; -pub use clock::*; - -mod conflict; -pub use conflict::*; - -mod filter; -pub use filter::*; - -mod leave; -pub use leave::*; - -mod member; -pub use member::*; - -mod message; -pub use message::*; - -mod join; -pub use join::*; - -mod tags; -pub use tags::*; - -mod push_pull; -pub use push_pull::*; - -mod user_event; -pub use user_event::*; - -mod query; -pub use query::*; - -mod version; -pub use version::*; - -#[cfg(feature = "encryption")] -mod key; -#[cfg(feature = "encryption")] -#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] -pub use key::*; - -#[cfg(debug_assertions)] -#[inline] -fn debug_assert_write_eq(actual: usize, expected: usize) { - debug_assert_eq!( - actual, expected, - "expect writting {expected} bytes, but actual write {actual} bytes" - ); -} - -// #[cfg(debug_assertions)] -// #[inline] -// fn debug_assert_read_eq(actual: usize, expected: usize) { -// debug_assert_eq!( -// actual, expected, -// "expect reading {expected} bytes, but actual read {actual} bytes" -// ); -// } diff --git a/serf-proto/src/member.rs b/serf-proto/src/member.rs deleted file mode 100644 index a3d12f4..0000000 --- a/serf-proto/src/member.rs +++ /dev/null @@ -1,474 +0,0 @@ -use std::sync::Arc; - -use memberlist_proto::{ - CheapClone, Data, DataRef, EncodeError, WireType, - utils::{merge, skip, split}, -}; - -use crate::TagsRef; - -use super::{ - DelegateVersion, MemberlistDelegateVersion, 
MemberlistProtocolVersion, Node, ProtocolVersion, - Tags, -}; - -const MEMBER_STATUS_NONE: u8 = 0; -const MEMBER_STATUS_ALIVE: u8 = 1; -const MEMBER_STATUS_LEAVING: u8 = 2; -const MEMBER_STATUS_LEFT: u8 = 3; -const MEMBER_STATUS_FAILED: u8 = 4; - -/// The member status. -#[derive( - Debug, Default, Copy, Clone, Eq, PartialEq, Hash, derive_more::IsVariant, derive_more::Display, -)] -#[repr(u8)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[non_exhaustive] -pub enum MemberStatus { - /// None status - #[display("none")] - #[default] - None, - /// Alive status - #[display("alive")] - Alive, - /// Leaving status - #[display("leaving")] - Leaving, - /// Left status - #[display("left")] - Left, - /// Failed status - #[display("failed")] - Failed, - /// Unknown state (used for forwards and backwards compatibility) - #[display("unknown({_0})")] - Unknown(u8), -} - -impl From for MemberStatus { - fn from(value: u8) -> Self { - match value { - MEMBER_STATUS_NONE => Self::None, - MEMBER_STATUS_ALIVE => Self::Alive, - MEMBER_STATUS_LEAVING => Self::Leaving, - MEMBER_STATUS_LEFT => Self::Left, - MEMBER_STATUS_FAILED => Self::Failed, - val => Self::Unknown(val), - } - } -} - -impl From for u8 { - fn from(val: MemberStatus) -> Self { - match val { - MemberStatus::None => MEMBER_STATUS_NONE, - MemberStatus::Alive => MEMBER_STATUS_ALIVE, - MemberStatus::Leaving => MEMBER_STATUS_LEAVING, - MemberStatus::Left => MEMBER_STATUS_LEFT, - MemberStatus::Failed => MEMBER_STATUS_FAILED, - MemberStatus::Unknown(val) => val, - } - } -} - -impl MemberStatus { - /// Get the string representation of the member status - #[inline] - pub fn as_str(&self) -> std::borrow::Cow<'static, str> { - std::borrow::Cow::Borrowed(match self { - Self::None => "none", - Self::Alive => "alive", - Self::Leaving => "leaving", - Self::Left => "left", - Self::Failed => "failed", - Self::Unknown(val) => return format!("unknown({})", val).into(), - }) - } -} - -/// A single member of 
the Serf cluster. -#[viewit::viewit(setters(prefix = "with"))] -#[derive(Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] -pub struct Member { - /// The node - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns the node")), - setter(attrs(doc = "Sets the node (Builder pattern)")) - )] - node: Node, - /// The tags - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns the tags")), - setter(attrs(doc = "Sets the tags (Builder pattern)")) - )] - tags: Arc, - /// The status - #[viewit( - getter(const, style = "ref", attrs(doc = "Returns the status")), - setter(attrs(doc = "Sets the status (Builder pattern)")) - )] - status: MemberStatus, - /// The memberlist protocol version - #[viewit( - getter(const, attrs(doc = "Returns the memberlist protocol version")), - setter( - const, - attrs(doc = "Sets the memberlist protocol version (Builder pattern)") - ) - )] - memberlist_protocol_version: MemberlistProtocolVersion, - /// The memberlist delegate version - #[viewit( - getter(const, attrs(doc = "Returns the memberlist delegate version")), - setter( - const, - attrs(doc = "Sets the memberlist delegate version (Builder pattern)") - ) - )] - memberlist_delegate_version: MemberlistDelegateVersion, - - /// The serf protocol version - #[viewit( - getter(const, attrs(doc = "Returns the serf protocol version")), - setter(const, attrs(doc = "Sets the serf protocol version (Builder pattern)")) - )] - protocol_version: ProtocolVersion, - /// The serf delegate version - #[viewit( - getter(const, attrs(doc = "Returns the serf delegate version")), - setter(const, attrs(doc = "Sets the serf delegate version (Builder pattern)")) - )] - delegate_version: DelegateVersion, -} - -impl Member { - /// Create a new member with the given node, tags, and status. - /// Other fields are set to their default values. 
- #[inline] - pub fn new(node: Node, tags: Tags, status: MemberStatus) -> Self { - Self { - node, - tags: Arc::new(tags), - status, - memberlist_protocol_version: MemberlistProtocolVersion::V1, - memberlist_delegate_version: MemberlistDelegateVersion::V1, - protocol_version: ProtocolVersion::V1, - delegate_version: DelegateVersion::V1, - } - } -} - -impl Clone for Member { - fn clone(&self) -> Self { - Self { - node: self.node.clone(), - tags: self.tags.clone(), - status: self.status, - memberlist_protocol_version: self.memberlist_protocol_version, - memberlist_delegate_version: self.memberlist_delegate_version, - protocol_version: self.protocol_version, - delegate_version: self.delegate_version, - } - } -} - -impl CheapClone for Member { - fn cheap_clone(&self) -> Self { - Self { - node: self.node.cheap_clone(), - tags: self.tags.cheap_clone(), - status: self.status, - memberlist_protocol_version: self.memberlist_protocol_version, - memberlist_delegate_version: self.memberlist_delegate_version, - protocol_version: self.protocol_version, - delegate_version: self.delegate_version, - } - } -} - -const NODE_TAG: u8 = 1; -const TAGS_TAG: u8 = 2; -const STATUS_TAG: u8 = 3; -const MEMBERLIST_PROTOCOL_VERSION_TAG: u8 = 4; -const MEMBERLIST_DELEGATE_VERSION_TAG: u8 = 5; -const PROTOCOL_VERSION_TAG: u8 = 6; -const DELEGATE_VERSION_TAG: u8 = 7; - -const NODE_BYTE: u8 = merge(WireType::LengthDelimited, NODE_TAG); -const TAGS_BYTE: u8 = merge(WireType::LengthDelimited, TAGS_TAG); -const STATUS_BYTE: u8 = merge(WireType::Byte, STATUS_TAG); -const MEMBERLIST_PROTOCOL_VERSION_BYTE: u8 = merge(WireType::Byte, MEMBERLIST_PROTOCOL_VERSION_TAG); -const MEMBERLIST_DELEGATE_VERSION_BYTE: u8 = merge(WireType::Byte, MEMBERLIST_DELEGATE_VERSION_TAG); -const PROTOCOL_VERSION_BYTE: u8 = merge(WireType::Byte, PROTOCOL_VERSION_TAG); -const DELEGATE_VERSION_BYTE: u8 = merge(WireType::Byte, DELEGATE_VERSION_TAG); - -/// A reference type to [`Member`] -#[viewit::viewit(vis_all = "", 
getters(vis_all = "pub"), setters(skip))] -#[derive(Debug, Clone, Copy, PartialEq)] -pub struct MemberRef<'a, I, A> { - /// The node - #[viewit(getter(const, style = "ref", attrs(doc = "Returns the node")))] - node: Node, - /// The tags - #[viewit(getter(const, style = "ref", attrs(doc = "Returns the tags")))] - tags: TagsRef<'a>, - /// The status - #[viewit(getter(const, style = "ref", attrs(doc = "Returns the status")))] - status: MemberStatus, - /// The memberlist protocol version - #[viewit(getter(const, attrs(doc = "Returns the memberlist protocol version")))] - memberlist_protocol_version: MemberlistProtocolVersion, - /// The memberlist delegate version - #[viewit(getter(const, attrs(doc = "Returns the memberlist delegate version")))] - memberlist_delegate_version: MemberlistDelegateVersion, - /// The serf protocol version - #[viewit(getter(const, attrs(doc = "Returns the serf protocol version")))] - protocol_version: ProtocolVersion, - /// The serf delegate version - #[viewit(getter(const, attrs(doc = "Returns the serf delegate version")))] - delegate_version: DelegateVersion, -} - -impl<'a, I, A> DataRef<'a, Member> for MemberRef<'a, I::Ref<'a>, A::Ref<'a>> -where - I: Data, - A: Data, -{ - fn decode(buf: &'a [u8]) -> Result<(usize, Self), memberlist_proto::DecodeError> - where - Self: Sized, - { - let mut offset = 0; - let buf_len = buf.len(); - - let mut node = None; - let mut tags = None; - let mut status = None; - let mut memberlist_protocol_version = None; - let mut memberlist_delegate_version = None; - let mut protocol_version = None; - let mut delegate_version = None; - - while offset < buf_len { - match buf[offset] { - NODE_BYTE => { - if node.is_some() { - return Err(memberlist_proto::DecodeError::duplicate_field( - "Member", "node", NODE_TAG, - )); - } - offset += 1; - let (size, val) = - , A::Ref<'_>> as DataRef<'_, Node>>::decode_length_delimited( - &buf[offset..], - )?; - node = Some(val); - offset += size; - } - TAGS_BYTE => { - if 
tags.is_some() { - return Err(memberlist_proto::DecodeError::duplicate_field( - "Member", "tags", TAGS_TAG, - )); - } - offset += 1; - let (size, val) = - as DataRef<'_, Tags>>::decode_length_delimited(&buf[offset..])?; - tags = Some(val); - offset += size; - } - STATUS_BYTE => { - if status.is_some() { - return Err(memberlist_proto::DecodeError::duplicate_field( - "Member", "status", STATUS_TAG, - )); - } - offset += 1; - status = Some(buf[offset].into()); - offset += 1; - } - MEMBERLIST_PROTOCOL_VERSION_BYTE => { - if memberlist_protocol_version.is_some() { - return Err(memberlist_proto::DecodeError::duplicate_field( - "Member", - "memberlist_protocol_version", - MEMBERLIST_PROTOCOL_VERSION_TAG, - )); - } - offset += 1; - memberlist_protocol_version = Some(buf[offset].into()); - offset += 1; - } - MEMBERLIST_DELEGATE_VERSION_BYTE => { - if memberlist_delegate_version.is_some() { - return Err(memberlist_proto::DecodeError::duplicate_field( - "Member", - "memberlist_delegate_version", - MEMBERLIST_DELEGATE_VERSION_TAG, - )); - } - offset += 1; - memberlist_delegate_version = Some(buf[offset].into()); - offset += 1; - } - PROTOCOL_VERSION_BYTE => { - if protocol_version.is_some() { - return Err(memberlist_proto::DecodeError::duplicate_field( - "Member", - "protocol_version", - PROTOCOL_VERSION_TAG, - )); - } - offset += 1; - protocol_version = Some(buf[offset].into()); - offset += 1; - } - DELEGATE_VERSION_BYTE => { - if delegate_version.is_some() { - return Err(memberlist_proto::DecodeError::duplicate_field( - "Member", - "delegate_version", - DELEGATE_VERSION_TAG, - )); - } - offset += 1; - delegate_version = Some(buf[offset].into()); - offset += 1; - } - other => { - offset += 1; - - let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type) - .map_err(memberlist_proto::DecodeError::unknown_wire_type)?; - offset += skip(wire_type, &buf[offset..])?; - } - } - } - - Ok(( - offset, - Self { - node: node.ok_or_else(|| 
memberlist_proto::DecodeError::missing_field("Member", "node"))?, - tags: tags.ok_or_else(|| memberlist_proto::DecodeError::missing_field("Member", "tags"))?, - status: status - .ok_or_else(|| memberlist_proto::DecodeError::missing_field("Member", "status"))?, - memberlist_protocol_version: memberlist_protocol_version.ok_or_else(|| { - memberlist_proto::DecodeError::missing_field("Member", "memberlist_protocol_version") - })?, - memberlist_delegate_version: memberlist_delegate_version.ok_or_else(|| { - memberlist_proto::DecodeError::missing_field("Member", "memberlist_delegate_version") - })?, - protocol_version: protocol_version.ok_or_else(|| { - memberlist_proto::DecodeError::missing_field("Member", "protocol_version") - })?, - delegate_version: delegate_version.ok_or_else(|| { - memberlist_proto::DecodeError::missing_field("Member", "delegate_version") - })?, - }, - )) - } -} - -impl Data for Member -where - I: Data, - A: Data, -{ - type Ref<'a> = MemberRef<'a, I::Ref<'a>, A::Ref<'a>>; - - fn from_ref(val: Self::Ref<'_>) -> Result - where - Self: Sized, - { - Ok(Self { - node: Node::from_ref(val.node)?, - tags: Tags::from_ref(val.tags)?.into(), - status: val.status, - memberlist_protocol_version: val.memberlist_protocol_version, - memberlist_delegate_version: val.memberlist_delegate_version, - protocol_version: val.protocol_version, - delegate_version: val.delegate_version, - }) - } - - fn encoded_len(&self) -> usize { - let mut len = 0; - len += 1 + self.node.encoded_len_with_length_delimited(); - len += 1 + self.tags.encoded_len_with_length_delimited(); - len += 1 + 1; // status - len += 1 + 1; // memberlist_protocol_version - len += 1 + 1; // memberlist_delegate_version - len += 1 + 1; // protocol_version - len += 1 + 1; // delegate_version - len - } - - fn encode(&self, buf: &mut [u8]) -> Result { - macro_rules! 
bail { - ($this:ident($offset:expr, $len:ident)) => { - if $offset >= $len { - return Err(EncodeError::insufficient_buffer(self.encoded_len(), $len)); - } - }; - } - - let buf_len = buf.len(); - let mut offset = 0; - bail!(self(offset, buf_len)); - - buf[offset] = NODE_BYTE; - offset += 1; - offset += self.node.encode_length_delimited(&mut buf[offset..])?; - - bail!(self(offset, buf_len)); - buf[offset] = TAGS_BYTE; - offset += 1; - offset += self.tags.encode_length_delimited(&mut buf[offset..])?; - - bail!(self(offset, buf_len)); - buf[offset] = STATUS_BYTE; - offset += 1; - - bail!(self(offset, buf_len)); - buf[offset] = self.status.into(); - offset += 1; - - bail!(self(offset, buf_len)); - buf[offset] = MEMBERLIST_PROTOCOL_VERSION_BYTE; - offset += 1; - - bail!(self(offset, buf_len)); - buf[offset] = self.memberlist_protocol_version.into(); - offset += 1; - - bail!(self(offset, buf_len)); - buf[offset] = MEMBERLIST_DELEGATE_VERSION_BYTE; - offset += 1; - - bail!(self(offset, buf_len)); - buf[offset] = self.memberlist_delegate_version.into(); - offset += 1; - - bail!(self(offset, buf_len)); - buf[offset] = PROTOCOL_VERSION_BYTE; - offset += 1; - - bail!(self(offset, buf_len)); - buf[offset] = self.protocol_version.into(); - offset += 1; - - bail!(self(offset, buf_len)); - buf[offset] = DELEGATE_VERSION_BYTE; - offset += 1; - - #[cfg(debug_assertions)] - super::debug_assert_write_eq(offset, self.encoded_len()); - - Ok(offset) - } -} From 3cd5984a2fbde5e09b64a3c42942748b65f96231 Mon Sep 17 00:00:00 2001 From: al8n Date: Sat, 1 Mar 2025 00:48:00 +0800 Subject: [PATCH 10/39] a --- serf-core/src/serf/query.rs | 1 - serf-core/src/types/conflict.rs | 7 ------- serf-core/src/types/join.rs | 14 -------------- 3 files changed, 22 deletions(-) diff --git a/serf-core/src/serf/query.rs b/serf-core/src/serf/query.rs index 12b8676..28802e2 100644 --- a/serf-core/src/serf/query.rs +++ b/serf-core/src/serf/query.rs @@ -513,7 +513,6 @@ where return Ok(false); } } - _ => {} } } 
Ok(true) diff --git a/serf-core/src/types/conflict.rs b/serf-core/src/types/conflict.rs index b89bec4..4ad5ab2 100644 --- a/serf-core/src/types/conflict.rs +++ b/serf-core/src/types/conflict.rs @@ -19,13 +19,6 @@ pub struct ConflictResponseMessage { member: Member, } -impl ConflictResponseMessage { - /// Create a new conflict response message - pub fn new(member: Member) -> Self { - Self { member } - } -} - /// The borrow type of conflict message #[viewit::viewit(setters(prefix = "with"))] #[derive(Debug, PartialEq)] diff --git a/serf-core/src/types/join.rs b/serf-core/src/types/join.rs index 4a0aff1..5c483da 100644 --- a/serf-core/src/types/join.rs +++ b/serf-core/src/types/join.rs @@ -39,20 +39,6 @@ impl JoinMessage { Self { ltime, id } } - /// Set the lamport time - #[inline] - pub fn set_ltime(&mut self, ltime: LamportTime) -> &mut Self { - self.ltime = ltime; - self - } - - /// Set the id of the node - #[inline] - pub fn set_id(&mut self, id: I) -> &mut Self { - self.id = id; - self - } - const fn id_byte() -> u8 where I: Data, From 333d32ee3427058816ffaa075b76e4f9716384b6 Mon Sep 17 00:00:00 2001 From: al8n Date: Sat, 1 Mar 2025 15:41:15 +0800 Subject: [PATCH 11/39] WIP --- Cargo.toml | 3 + serf-core/Cargo.toml | 7 +- serf-core/src/lib.rs | 4 +- serf-core/src/serf.rs | 6 +- serf-core/src/serf/base.rs | 2 +- serf-core/src/types.rs | 10 + serf-core/src/types/arbitrary_impl.rs | 9 +- serf-core/src/types/clock.rs | 11 - serf-core/src/types/conflict.rs | 2 +- serf-core/src/{ => types}/coordinate.rs | 1016 ++++++++++++----------- serf-core/src/types/message.rs | 16 + serf-core/src/types/quickcheck_impl.rs | 311 +++++++ serf-core/src/types/tags.rs | 28 +- serf-core/src/types/tests.rs | 64 ++ 14 files changed, 943 insertions(+), 546 deletions(-) rename serf-core/src/{ => types}/coordinate.rs (98%) create mode 100644 serf-core/src/types/quickcheck_impl.rs create mode 100644 serf-core/src/types/tests.rs diff --git a/Cargo.toml b/Cargo.toml index f2eb25a..252b999 100644 
--- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,9 @@ smol_str = "0.3" smallvec = "1" rand = "0.9" +arbitrary = { version = "1", default-features = false, features = ["derive"] } +quickcheck = { version = "1", default-features = false } + memberlist-proto = { version = "0.1", path = "../memberlist/memberlist-proto", default-features = false } memberlist-core = { version = "0.6", path = "../memberlist/memberlist-core", default-features = false } memberlist = { version = "0.6", path = "../memberlist/memberlist", default-features = false } diff --git a/serf-core/Cargo.toml b/serf-core/Cargo.toml index cc6dbf8..abecc7c 100644 --- a/serf-core/Cargo.toml +++ b/serf-core/Cargo.toml @@ -65,8 +65,8 @@ serde_json = "1" base64 = { version = "0.22", optional = true } -arbitrary = { version = "1", optional = true, default-features = false, features = ["derive"] } -quickcheck = { version = "1", optional = true, default-features = false } +arbitrary = { workspace = true, optional = true, default-features = false, features = ["derive"] } +quickcheck = { workspace = true, optional = true, default-features = false } # test features paste = { version = "1", optional = true } @@ -81,6 +81,9 @@ agnostic-lite = { version = "0.5", features = ["tokio"] } tokio = { version = "1", features = ["full"] } futures = { workspace = true, features = ["executor"] } tempfile = "3" +memberlist-core = { workspace = true, features = ["quickcheck", "arbitrary"] } +quickcheck_macros = "1" +quickcheck.workspace = true [package.metadata.docs.rs] all-features = true diff --git a/serf-core/src/lib.rs b/serf-core/src/lib.rs index 5255732..9737fdb 100644 --- a/serf-core/src/lib.rs +++ b/serf-core/src/lib.rs @@ -10,9 +10,6 @@ pub(crate) mod broadcast; mod coalesce; -/// Coordinate. -pub mod coordinate; - /// Events for [`Serf`] pub mod event; @@ -27,6 +24,7 @@ pub use options::*; /// The types used in `serf`. pub mod types; +pub use types::coordinate; /// Secret key management. 
#[cfg(feature = "encryption")] diff --git a/serf-core/src/serf.rs b/serf-core/src/serf.rs index ee72a68..4c78811 100644 --- a/serf-core/src/serf.rs +++ b/serf-core/src/serf.rs @@ -17,11 +17,13 @@ use memberlist_core::{ use super::{ Options, broadcast::SerfBroadcast, - coordinate::{Coordinate, CoordinateClient}, delegate::{CompositeDelegate, Delegate}, event::CrateEvent, snapshot::SnapshotHandle, - types::{LamportClock, LamportTime, Members, UserEvents}, + types::{ + LamportClock, LamportTime, Members, UserEvents, + coordinate::{Coordinate, CoordinateClient}, + }, }; mod api; diff --git a/serf-core/src/serf/base.rs b/serf-core/src/serf/base.rs index e4797e5..e38eb5f 100644 --- a/serf-core/src/serf/base.rs +++ b/serf-core/src/serf/base.rs @@ -17,7 +17,6 @@ use smol_str::SmolStr; use crate::{ QueueOptions, coalesce::{MemberEventCoalescer, UserEventCoalescer, coalesced_event}, - coordinate::CoordinateOptions, error::Error, event::{InternalQueryEvent, MemberEvent, MemberEventType, QueryContext, QueryEvent}, snapshot::{Snapshot, open_and_replay_snapshot}, @@ -25,6 +24,7 @@ use crate::{ DelegateVersion, Epoch, JoinMessage, LeaveMessage, Member, MemberState, MemberStatus, MemberlistDelegateVersion, MemberlistProtocolVersion, MessageType, NodeIntent, ProtocolVersion, QueryFlag, QueryMessage, QueryResponseMessage, UserEvent, UserEventMessage, + coordinate::CoordinateOptions, }, }; diff --git a/serf-core/src/types.rs b/serf-core/src/types.rs index 0a44972..5bb0720 100644 --- a/serf-core/src/types.rs +++ b/serf-core/src/types.rs @@ -8,9 +8,19 @@ pub use memberlist_core::proto::{ #[cfg(feature = "arbitrary")] mod arbitrary_impl; +#[cfg(feature = "quickcheck")] +mod quickcheck_impl; + +#[cfg(test)] +mod tests; + + mod clock; pub use clock::*; +/// Vivialdi coordinate implementation +pub mod coordinate; + mod conflict; pub(crate) use conflict::*; diff --git a/serf-core/src/types/arbitrary_impl.rs b/serf-core/src/types/arbitrary_impl.rs index 65a4518..197556c 100644 --- 
a/serf-core/src/types/arbitrary_impl.rs +++ b/serf-core/src/types/arbitrary_impl.rs @@ -3,7 +3,7 @@ use std::{ hash::Hash, }; -use super::{Filter, TagFilter}; +use super::{Filter, MessageType, TagFilter}; use arbitrary::{Arbitrary, Unstructured}; use indexmap::{IndexMap, IndexSet}; use memberlist_core::proto::TinyVec; @@ -158,3 +158,10 @@ where }) } } + +impl<'a> Arbitrary<'a> for MessageType { + fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { + u.arbitrary::() + .map(|val| Self::from(val % Self::ALL.len() as u8)) + } +} diff --git a/serf-core/src/types/clock.rs b/serf-core/src/types/clock.rs index ba26c9b..4233ecd 100644 --- a/serf-core/src/types/clock.rs +++ b/serf-core/src/types/clock.rs @@ -172,17 +172,6 @@ impl LamportClock { } } -#[cfg(feature = "quickcheck")] -const _: () = { - use quickcheck::{Arbitrary, Gen}; - - impl Arbitrary for LamportTime { - fn arbitrary(g: &mut Gen) -> Self { - Self(u64::arbitrary(g)) - } - } -}; - #[test] fn test_lamport_clock() { let l = LamportClock::new(); diff --git a/serf-core/src/types/conflict.rs b/serf-core/src/types/conflict.rs index 4ad5ab2..03fb0b8 100644 --- a/serf-core/src/types/conflict.rs +++ b/serf-core/src/types/conflict.rs @@ -7,7 +7,7 @@ use super::*; /// A conflict message #[viewit::viewit(setters(prefix = "with"))] -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub struct ConflictResponseMessage { diff --git a/serf-core/src/coordinate.rs b/serf-core/src/types/coordinate.rs similarity index 98% rename from serf-core/src/coordinate.rs rename to serf-core/src/types/coordinate.rs index 7a285df..91dbc72 100644 --- a/serf-core/src/coordinate.rs +++ b/serf-core/src/types/coordinate.rs @@ -1,19 +1,19 @@ +use core::time::Duration; + +use memberlist_core::proto::{ + Data, DataRef, DecodeError, EncodeError, RepeatedDecoder, WireType, + utils::{merge, 
skip, split}, +}; +use rand::Rng; +use smallvec::SmallVec; + use std::{ collections::HashMap, sync::atomic::{AtomicUsize, Ordering}, - time::Duration, }; -use memberlist_core::{ - CheapClone, - proto::{ - Data, DataRef, DecodeError, EncodeError, RepeatedDecoder, WireType, - utils::{merge, skip, split}, - }, -}; +use memberlist_core::CheapClone; use parking_lot::RwLock; -use rand::Rng; -use smallvec::SmallVec; /// Used to convert float seconds to nanoseconds. const SECONDS_TO_NANOSECONDS: f64 = 1.0e9; @@ -29,20 +29,6 @@ const DEFAULT_ADJUSTMENT_WINDOW_SIZE: usize = 20; const DEFAULT_LATENCY_FILTER_SAMPLES_SIZE: usize = 8; -/// Error type for the [`Coordinate`]. -#[derive(Debug, thiserror::Error, PartialEq, Eq)] -pub enum CoordinateError { - /// Returned when the dimensions of the coordinates are not compatible. - #[error("dimensions aren't compatible")] - DimensionalityMismatch, - /// Returned when the coordinate is invalid. - #[error("invalid coordinate")] - InvalidCoordinate, - /// Returned when the round trip time is not in a valid range. - #[error("round trip time not in valid range, duration {0:?} is not a value less than 10s")] - InvalidRTT(Duration), -} - /// CoordinateOptions is used to set the parameters of the Vivaldi-based coordinate mapping /// algorithm. /// @@ -217,549 +203,265 @@ impl CoordinateOptions { } } -/// Used to record events that occur when updating coordinates. -#[viewit::viewit(setters(prefix = "with"))] -#[derive(Debug, Copy, Clone, Eq, PartialEq)] +/// A specialized structure for holding network coordinates for the +/// Vivaldi-based coordinate mapping algorithm. All of the fields should be public +/// to enable this to be serialized. All values in here are in units of seconds. 
+#[viewit::viewit(getters(style = "move"), setters(prefix = "with"))] +#[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct CoordinateClientStats { - /// Incremented any time we reset our local coordinate because - /// our calculations have resulted in an invalid state. +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct Coordinate { + /// The Euclidean portion of the coordinate. This is used along + /// with the other fields to provide an overall distance estimate. The + /// units here are seconds. #[viewit( getter( const, - style = "move", - attrs( - doc = "Returns the number of times we reset our local coordinate because our calculations have resulted in an invalid state." - ) + style = "ref", + attrs(doc = "Returns the Euclidean portion of the coordinate.") ), - setter(attrs( - doc = "Sets the number of times we reset our local coordinate because our calculations have resulted in an invalid state." - )) + setter(attrs(doc = "Sets the Euclidean portion of the coordinate.")) )] - resets: usize, + #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::types::arbitrary_impl::into::, SmallVec<[f64; DEFAULT_DIMENSIONALITY]>>))] + portion: SmallVec<[f64; DEFAULT_DIMENSIONALITY]>, + /// Reflects the confidence in the given coordinate and is updated + /// dynamically by the Vivaldi Client. This is dimensionless. + #[viewit( + getter(const, attrs(doc = "Returns the confidence in the given coordinate.")), + setter(attrs(doc = "Sets the confidence in the given coordinate.")) + )] + error: f64, + /// A distance offset computed based on a calculation over + /// observations from all other nodes over a fixed window and is updated + /// dynamically by the Vivaldi Client. The units here are seconds. 
+ #[viewit( + getter(const, attrs(doc = "Returns the distance offset.")), + setter(attrs(doc = "Sets the distance offset.")) + )] + adjustment: f64, + /// A distance offset that accounts for non-Euclidean effects + /// which model the access links from nodes to the core Internet. The access + /// links are usually set by bandwidth and congestion, and the core links + /// usually follow distance based on geography. + #[viewit( + getter( + const, + attrs(doc = "Returns the distance offset that accounts for non-Euclidean effects.") + ), + setter(attrs(doc = "Sets the distance offset that accounts for non-Euclidean effects.")) + )] + height: f64, } -impl Default for CoordinateClientStats { +impl Default for Coordinate { #[inline] fn default() -> Self { Self::new() } } -impl CoordinateClientStats { +impl Coordinate { + /// Creates a new coordinate at the origin, using the default options + /// to supply key initial values. #[inline] - const fn new() -> Self { - Self { resets: 0 } + pub fn new() -> Self { + Self::with_options(CoordinateOptions::new()) } -} - -struct CoordinateClientInner { - /// The current estimate of the client's network coordinate. - coord: Coordinate, - - /// Origin is a coordinate sitting at the origin. - origin: Coordinate, - - /// Contains the tuning parameters that govern the performance of - /// the algorithm. - opts: CoordinateOptions, - - /// The current index into the adjustmentSamples slice. - adjustment_index: usize, - - /// Used to store samples for the adjustment calculation. - adjustment_samples: SmallVec<[f64; DEFAULT_ADJUSTMENT_WINDOW_SIZE]>, - /// Used to store the last several RTT samples, - /// keyed by node name. We will use the config's LatencyFilterSamples - /// value to determine how many samples we keep, per node. 
- latency_filter_samples: HashMap>, -} - -impl CoordinateClientInner -where - I: CheapClone + Eq + core::hash::Hash, -{ - /// Applies a small amount of gravity to pull coordinates towards - /// the center of the coordinate system to combat drift. This assumes that the - /// mutex is locked already. + /// Creates a new coordinate at the origin, using the given options + /// to supply key initial values. #[inline] - fn update_gravity(&mut self) { - let dist = self.origin.distance_to(&self.coord).as_secs(); - let force = -1.0 * f64::powf((dist as f64) / self.opts.gravity_rho, 2.0); - self - .coord - .apply_force_in_place(self.opts.height_min, force, &self.origin); + pub fn with_options(opts: CoordinateOptions) -> Self { + let mut vec = SmallVec::with_capacity(opts.dimensionality); + vec.resize(opts.dimensionality, 0.0); + Self { + portion: vec, + error: opts.vivaldi_error_max, + adjustment: 0.0, + height: opts.height_min, + } } + /// Returns true if the coordinate is valid. #[inline] - fn latency_filter(&mut self, node: &I, rtt_seconds: f64) -> f64 { - let samples = self - .latency_filter_samples - .entry(node.cheap_clone()) - .or_insert_with(|| SmallVec::with_capacity(self.opts.latency_filter_size)); - - // Add the new sample and trim the list, if needed. - samples.push(rtt_seconds); - if samples.len() > self.opts.latency_filter_size { - samples.remove(0); - } - // Sort a copy of the samples and return the median. - let mut tmp = SmallVec::<[f64; DEFAULT_LATENCY_FILTER_SAMPLES_SIZE]>::from_slice(samples); - tmp.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); - tmp[tmp.len() / 2] + pub fn is_valid(&self) -> bool { + self.portion.iter().all(|&f| f.is_finite()) + && self.error.is_finite() + && self.adjustment.is_finite() + && self.height.is_finite() } - /// Updates the Vivaldi portion of the client's coordinate. This - /// assumes that the mutex has been locked already. 
- fn update_vivaldi(&mut self, other: &Coordinate, mut rtt_seconds: f64) { - const ZERO_THRESHOLD: f64 = 1.0e-6; - - let dist = self.coord.distance_to(other).as_secs_f64(); - rtt_seconds = rtt_seconds.max(ZERO_THRESHOLD); + /// Returns true if the dimensions of the coordinates are compatible. + #[inline] + pub fn is_compatible_with(&self, other: &Self) -> bool { + self.portion.len() == other.portion.len() + } - let wrongness = ((dist - rtt_seconds) / rtt_seconds).abs(); + /// Returns the result of applying the force from the direction of the + /// other coordinate. + pub fn apply_force(&self, height_min: f64, force: f64, other: &Self) -> Self { + assert!( + self.is_compatible_with(other), + "coordinate dimensionality does not match" + ); - let total_error = (self.coord.error + other.error).max(ZERO_THRESHOLD); + let mut ret = self.clone(); + let (mut unit, mag) = unit_vector_at(&self.portion, &other.portion); + add_in_place(&mut ret.portion, mul_in_place(&mut unit, force)); + if mag > ZERO_THRESHOLD { + ret.height = (ret.height + other.height) * force / mag + ret.height; + ret.height = ret.height.max(height_min); + } + ret + } - let weight = self.coord.error / total_error; - self.coord.error = ((self.opts.vivaldi_ce * weight * wrongness) - + (self.coord.error * (1.0 - self.opts.vivaldi_ce * weight))) - .min(self.opts.vivaldi_error_max); + /// Apply the result of applying the force from the direction of the + /// other coordinate to self. 
+ pub fn apply_force_in_place(&mut self, height_min: f64, force: f64, other: &Self) { + assert!( + self.is_compatible_with(other), + "coordinate dimensionality does not match" + ); + let (mut unit, mag) = unit_vector_at(&self.portion, &other.portion); + add_in_place(&mut self.portion, mul_in_place(&mut unit, force)); - let force = self.opts.vivaldi_cc * weight * (rtt_seconds - dist); - self - .coord - .apply_force_in_place(self.opts.height_min, force, other); + if mag > ZERO_THRESHOLD { + self.height = (self.height + other.height) * force / mag + self.height; + self.height = self.height.max(height_min); + } } - /// Updates the adjustment portion of the client's coordinate, if - /// the feature is enabled. This assumes that the mutex has been locked already. - fn update_adjustment(&mut self, other: &Coordinate, rtt_seconds: f64) { - if self.opts.adjustment_window_size == 0 { - return; - } - // Note that the existing adjustment factors don't figure in to this - // calculation so we use the raw distance here. - let dist = self.coord.raw_distance_to(other); - self.adjustment_samples[self.adjustment_index] = rtt_seconds - dist; - self.adjustment_index = (self.adjustment_index + 1) % self.opts.adjustment_window_size; + /// Returns the distance between this coordinate and the other + /// coordinate, including adjustments. 
+ pub fn distance_to(&self, other: &Self) -> Duration { + assert!( + self.is_compatible_with(other), + "coordinate dimensionality does not match" + ); - self.coord.adjustment = - self.adjustment_samples.iter().sum::() / (2.0 * self.opts.adjustment_window_size as f64); + let dist = self.raw_distance_to(other); + let adjusted_dist = dist + self.adjustment + other.adjustment; + let dist = if adjusted_dist > 0.0 { + adjusted_dist + } else { + dist + }; + Duration::from_nanos((dist * SECONDS_TO_NANOSECONDS) as u64) } -} - -/// Manages the estimated network coordinate for a given node, and adjusts -/// it as the node observes round trip times and estimated coordinates from other -/// nodes. The core algorithm is based on Vivaldi, see the documentation for Config -/// for more details. -/// -/// `CoordinateClient` is thread-safe. -// TODO: are there any better ways to avoid using a RwLock? -pub struct CoordinateClient { - inner: RwLock>, - /// Used to record events that occur when updating coordinates. - stats: AtomicUsize, -} -impl Default for CoordinateClient { #[inline] - fn default() -> Self { - Self::new() + pub(crate) fn raw_distance_to(&self, other: &Self) -> f64 { + magnitude_in_place(diff_in_place(&self.portion, &other.portion)) + self.height + other.height } } -impl CoordinateClient { - /// Creates a new client. - #[inline] - pub fn new() -> Self { - Self { - inner: RwLock::new(CoordinateClientInner { - coord: Coordinate::new(), - origin: Coordinate::new(), - opts: CoordinateOptions::new(), - adjustment_index: 0, - adjustment_samples: SmallVec::from_slice(&[0.0; DEFAULT_ADJUSTMENT_WINDOW_SIZE]), - latency_filter_samples: HashMap::new(), - }), - stats: AtomicUsize::new(0), - } - } - - /// Creates a new client with given options. 
- #[inline] - pub fn with_options(opts: CoordinateOptions) -> Self { - let mut samples = SmallVec::with_capacity(opts.adjustment_window_size); - samples.resize(opts.adjustment_window_size, 0.0); - Self { - inner: RwLock::new(CoordinateClientInner { - coord: Coordinate::with_options(opts.clone()), - origin: Coordinate::with_options(opts.clone()), - opts, - adjustment_index: 0, - adjustment_samples: samples, - latency_filter_samples: HashMap::new(), - }), - stats: AtomicUsize::new(0), - } - } - - /// Returns a copy of the coordinate for this client. - #[inline] - pub fn get_coordinate(&self) -> Coordinate { - self.inner.read().coord.clone() +#[inline] +fn add_in_place(vec1: &mut [f64], vec2: &[f64]) { + for (x, y) in vec1.iter_mut().zip(vec2.iter()) { + *x += y; } +} - /// Forces the client's coordinate to a known state. - #[inline] - pub fn set_coordinate(&self, coord: Coordinate) -> Result<(), CoordinateError> { - let mut l = self.inner.write(); - Self::check_coordinate(&l.coord, &coord).map(|_| l.coord = coord) - } +/// Returns difference between the vec1 and vec2. This assumes the +/// dimensions have already been checked to be compatible. +#[inline] +fn diff(vec1: &[f64], vec2: &[f64]) -> SmallVec<[f64; DEFAULT_DIMENSIONALITY]> { + vec1.iter().zip(vec2).map(|(x, y)| x - y).collect() +} - /// Returns a copy of stats for the client. - #[inline] - pub fn stats(&self) -> CoordinateClientStats { - CoordinateClientStats { - resets: self.stats.load(Ordering::Relaxed), - } - } +/// computes difference between the vec1 and vec2 in place. This assumes the +/// dimensions have already been checked to be compatible. +#[inline] +fn diff_in_place<'a>(vec1: &'a [f64], vec2: &'a [f64]) -> impl Iterator + 'a { + vec1.iter().zip(vec2).map(|(x, y)| x - y) +} - /// Returns the estimated RTT from the client's coordinate to other, the - /// coordinate for another node. 
- #[inline] - pub fn distance_to(&self, coord: &Coordinate) -> Duration { - self.inner.read().coord.distance_to(coord) +/// multiplied by a scalar factor in place. +#[inline] +fn mul_in_place(vec: &mut [f64], factor: f64) -> &mut [f64] { + for x in vec.iter_mut() { + *x *= factor; } + vec +} - /// Returns an error if the coordinate isn't compatible with - /// this client, or if the coordinate itself isn't valid. This assumes the mutex - /// has been locked already. - #[inline] - fn check_coordinate(this: &Coordinate, coord: &Coordinate) -> Result<(), CoordinateError> { - if !this.is_compatible_with(coord) { - return Err(CoordinateError::DimensionalityMismatch); - } +/// Computes the magnitude of the vec. +#[inline] +fn magnitude_in_place(vec: impl Iterator) -> f64 { + vec.fold(0.0, |acc, x| acc + x * x).sqrt() +} - if !coord.is_valid() { - return Err(CoordinateError::InvalidCoordinate); - } +/// Returns a unit vector pointing at vec1 from vec2. If the two +/// positions are the same then a random unit vector is returned. We also return +/// the distance between the points for use in the later height calculation. +fn unit_vector_at(vec1: &[f64], vec2: &[f64]) -> (SmallVec<[f64; DEFAULT_DIMENSIONALITY]>, f64) { + let mut ret = diff(vec1, vec2); - Ok(()) + let mag = magnitude_in_place(ret.iter().copied()); + if mag > ZERO_THRESHOLD { + mul_in_place(&mut ret, mag.recip()); + return (ret, mag); } -} -impl CoordinateClient -where - I: CheapClone + Eq + core::hash::Hash, -{ - /// Removes any client state for the given node. - #[inline] - pub fn forget_node(&self, node: &I) { - self.inner.write().latency_filter_samples.remove(node); + for x in ret.iter_mut() { + *x = rand_f64() - 0.5; } - /// Update takes other, a coordinate for another node, and rtt, a round trip - /// time observation for a ping to that node, and updates the estimated position of - /// the client's coordinate. Returns the updated coordinate. 
- pub fn update( - &self, - node: &I, - other: &Coordinate, - rtt: Duration, - ) -> Result { - let mut l = self.inner.write(); - Self::check_coordinate(&l.coord, other)?; - - // The code down below can handle zero RTTs, which we have seen in - // https://github.com/hashicorp/consul/issues/3789, presumably in - // environments with coarse-grained monotonic clocks (we are still - // trying to pin this down). In any event, this is ok from a code PoV - // so we don't need to alert operators with spammy messages. We did - // add a counter so this is still observable, though. - const MAX_RTT: Duration = Duration::from_secs(10); - - if rtt > MAX_RTT { - return Err(CoordinateError::InvalidRTT(rtt)); - } - - #[cfg(feature = "metrics")] - if rtt.is_zero() { - metrics::counter!("serf.coordinate.zero-rtt", l.opts.metric_labels.iter()).increment(1); - } + let mag = magnitude_in_place(ret.iter().copied()); + if mag > ZERO_THRESHOLD { + mul_in_place(&mut ret, mag.recip()); + return (ret, 0.0); + } - let rtt_seconds = l.latency_filter(node, rtt.as_secs_f64()); - l.update_vivaldi(other, rtt_seconds); - l.update_adjustment(other, rtt_seconds); - l.update_gravity(); + // And finally just give up and make a unit vector along the first + // dimension. This should be exceedingly rare. + ret.fill(0.0); + ret[0] = 1.0; + (ret, 0.0) +} - if !l.coord.is_valid() { - self.stats.fetch_add(1, Ordering::Acquire); - l.coord = Coordinate::with_options(l.opts.clone()); +fn rand_f64() -> f64 { + let mut rng = rand::rng(); + loop { + let f = (rng.random_range::(0..(1u64 << 63u64)) as f64) / ((1u64 << 63u64) as f64); + if f == 1.0 { + continue; } - - Ok(l.coord.clone()) + return f; } } -/// A specialized structure for holding network coordinates for the -/// Vivaldi-based coordinate mapping algorithm. All of the fields should be public -/// to enable this to be serialized. All values in here are in units of seconds. 
-#[viewit::viewit(getters(style = "move"), setters(prefix = "with"))] -#[derive(Debug, Clone, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Coordinate { - /// The Euclidean portion of the coordinate. This is used along - /// with the other fields to provide an overall distance estimate. The - /// units here are seconds. - #[viewit( - getter( - const, - style = "ref", - attrs(doc = "Returns the Euclidean portion of the coordinate.") - ), - setter(attrs(doc = "Sets the Euclidean portion of the coordinate.")) - )] - portion: SmallVec<[f64; DEFAULT_DIMENSIONALITY]>, - /// Reflects the confidence in the given coordinate and is updated - /// dynamically by the Vivaldi Client. This is dimensionless. - #[viewit( - getter(const, attrs(doc = "Returns the confidence in the given coordinate.")), - setter(attrs(doc = "Sets the confidence in the given coordinate.")) - )] +const PORTION_TAG: u8 = 1; +const ERROR_TAG: u8 = 2; +const ADJUSTMENT_TAG: u8 = 3; +const HEIGHT_TAG: u8 = 4; +const PORTION_BYTE: u8 = merge(WireType::LengthDelimited, PORTION_TAG); +const ERROR_BYTE: u8 = merge(WireType::Fixed64, ERROR_TAG); +const ADJUSTMENT_BYTE: u8 = merge(WireType::Fixed64, ADJUSTMENT_TAG); +const HEIGHT_BYTE: u8 = merge(WireType::Fixed64, HEIGHT_TAG); + +/// The reference type to [`Coordinate`]. +#[derive(Copy, Clone, Debug, PartialEq)] +pub struct CoordinateRef<'a> { + portion: RepeatedDecoder<'a>, error: f64, - /// A distance offset computed based on a calculation over - /// observations from all other nodes over a fixed window and is updated - /// dynamically by the Vivaldi Client. The units here are seconds. - #[viewit( - getter(const, attrs(doc = "Returns the distance offset.")), - setter(attrs(doc = "Sets the distance offset.")) - )] adjustment: f64, - /// A distance offset that accounts for non-Euclidean effects - /// which model the access links from nodes to the core Internet. 
The access - /// links are usually set by bandwidth and congestion, and the core links - /// usually follow distance based on geography. - #[viewit( - getter( - const, - attrs(doc = "Returns the distance offset that accounts for non-Euclidean effects.") - ), - setter(attrs(doc = "Sets the distance offset that accounts for non-Euclidean effects.")) - )] height: f64, } -impl Default for Coordinate { - #[inline] - fn default() -> Self { - Self::new() - } -} +impl<'a> DataRef<'a, Coordinate> for CoordinateRef<'a> { + fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> + where + Self: Sized, + { + let mut offset = 0; + let buf_len = buf.len(); -impl Coordinate { - /// Creates a new coordinate at the origin, using the default options - /// to supply key initial values. - #[inline] - pub fn new() -> Self { - Self::with_options(CoordinateOptions::new()) - } - - /// Creates a new coordinate at the origin, using the given options - /// to supply key initial values. - #[inline] - pub fn with_options(opts: CoordinateOptions) -> Self { - let mut vec = SmallVec::with_capacity(opts.dimensionality); - vec.resize(opts.dimensionality, 0.0); - Self { - portion: vec, - error: opts.vivaldi_error_max, - adjustment: 0.0, - height: opts.height_min, - } - } - - /// Returns true if the coordinate is valid. - #[inline] - pub fn is_valid(&self) -> bool { - self.portion.iter().all(|&f| f.is_finite()) - && self.error.is_finite() - && self.adjustment.is_finite() - && self.height.is_finite() - } - - /// Returns true if the dimensions of the coordinates are compatible. - #[inline] - pub fn is_compatible_with(&self, other: &Self) -> bool { - self.portion.len() == other.portion.len() - } - - /// Returns the result of applying the force from the direction of the - /// other coordinate. 
- pub fn apply_force(&self, height_min: f64, force: f64, other: &Self) -> Self { - assert!( - self.is_compatible_with(other), - "coordinate dimensionality does not match" - ); - - let mut ret = self.clone(); - let (mut unit, mag) = unit_vector_at(&self.portion, &other.portion); - add_in_place(&mut ret.portion, mul_in_place(&mut unit, force)); - if mag > ZERO_THRESHOLD { - ret.height = (ret.height + other.height) * force / mag + ret.height; - ret.height = ret.height.max(height_min); - } - ret - } - - /// Apply the result of applying the force from the direction of the - /// other coordinate to self. - pub fn apply_force_in_place(&mut self, height_min: f64, force: f64, other: &Self) { - assert!( - self.is_compatible_with(other), - "coordinate dimensionality does not match" - ); - let (mut unit, mag) = unit_vector_at(&self.portion, &other.portion); - add_in_place(&mut self.portion, mul_in_place(&mut unit, force)); - - if mag > ZERO_THRESHOLD { - self.height = (self.height + other.height) * force / mag + self.height; - self.height = self.height.max(height_min); - } - } - - /// Returns the distance between this coordinate and the other - /// coordinate, including adjustments. - pub fn distance_to(&self, other: &Self) -> Duration { - assert!( - self.is_compatible_with(other), - "coordinate dimensionality does not match" - ); - - let dist = self.raw_distance_to(other); - let adjusted_dist = dist + self.adjustment + other.adjustment; - let dist = if adjusted_dist > 0.0 { - adjusted_dist - } else { - dist - }; - Duration::from_nanos((dist * SECONDS_TO_NANOSECONDS) as u64) - } - - #[inline] - fn raw_distance_to(&self, other: &Self) -> f64 { - magnitude_in_place(diff_in_place(&self.portion, &other.portion)) + self.height + other.height - } -} - -#[inline] -fn add_in_place(vec1: &mut [f64], vec2: &[f64]) { - for (x, y) in vec1.iter_mut().zip(vec2.iter()) { - *x += y; - } -} - -/// Returns difference between the vec1 and vec2. 
This assumes the -/// dimensions have already been checked to be compatible. -#[inline] -fn diff(vec1: &[f64], vec2: &[f64]) -> SmallVec<[f64; DEFAULT_DIMENSIONALITY]> { - vec1.iter().zip(vec2).map(|(x, y)| x - y).collect() -} - -/// computes difference between the vec1 and vec2 in place. This assumes the -/// dimensions have already been checked to be compatible. -#[inline] -fn diff_in_place<'a>(vec1: &'a [f64], vec2: &'a [f64]) -> impl Iterator + 'a { - vec1.iter().zip(vec2).map(|(x, y)| x - y) -} - -/// multiplied by a scalar factor in place. -#[inline] -fn mul_in_place(vec: &mut [f64], factor: f64) -> &mut [f64] { - for x in vec.iter_mut() { - *x *= factor; - } - vec -} - -/// Computes the magnitude of the vec. -#[inline] -fn magnitude_in_place(vec: impl Iterator) -> f64 { - vec.fold(0.0, |acc, x| acc + x * x).sqrt() -} - -/// Returns a unit vector pointing at vec1 from vec2. If the two -/// positions are the same then a random unit vector is returned. We also return -/// the distance between the points for use in the later height calculation. -fn unit_vector_at(vec1: &[f64], vec2: &[f64]) -> (SmallVec<[f64; DEFAULT_DIMENSIONALITY]>, f64) { - let mut ret = diff(vec1, vec2); - - let mag = magnitude_in_place(ret.iter().copied()); - if mag > ZERO_THRESHOLD { - mul_in_place(&mut ret, mag.recip()); - return (ret, mag); - } - - for x in ret.iter_mut() { - *x = rand_f64() - 0.5; - } - - let mag = magnitude_in_place(ret.iter().copied()); - if mag > ZERO_THRESHOLD { - mul_in_place(&mut ret, mag.recip()); - return (ret, 0.0); - } - - // And finally just give up and make a unit vector along the first - // dimension. This should be exceedingly rare. 
- ret.fill(0.0); - ret[0] = 1.0; - (ret, 0.0) -} - -fn rand_f64() -> f64 { - let mut rng = rand::rng(); - loop { - let f = (rng.random_range::(0..(1u64 << 63u64)) as f64) / ((1u64 << 63u64) as f64); - if f == 1.0 { - continue; - } - return f; - } -} - -const PORTION_TAG: u8 = 1; -const ERROR_TAG: u8 = 2; -const ADJUSTMENT_TAG: u8 = 3; -const HEIGHT_TAG: u8 = 4; -const PORTION_BYTE: u8 = merge(WireType::LengthDelimited, PORTION_TAG); -const ERROR_BYTE: u8 = merge(WireType::Fixed64, ERROR_TAG); -const ADJUSTMENT_BYTE: u8 = merge(WireType::Fixed64, ADJUSTMENT_TAG); -const HEIGHT_BYTE: u8 = merge(WireType::Fixed64, HEIGHT_TAG); - -/// The reference type to [`Coordinate`]. -#[derive(Copy, Clone, Debug, PartialEq)] -pub struct CoordinateRef<'a> { - portion: RepeatedDecoder<'a>, - error: f64, - adjustment: f64, - height: f64, -} - -impl<'a> DataRef<'a, Coordinate> for CoordinateRef<'a> { - fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> - where - Self: Sized, - { - let mut offset = 0; - let buf_len = buf.len(); - - let mut portion_offsets = None; - let mut num_portions = 0; - let mut error = None; - let mut adjustment = None; - let mut height = None; + let mut portion_offsets = None; + let mut num_portions = 0; + let mut error = None; + let mut adjustment = None; + let mut height = None; while offset < buf_len { match buf[offset] { @@ -918,6 +620,306 @@ impl Data for Coordinate { } } +/// Error type for the [`Coordinate`]. +#[derive(Debug, thiserror::Error, PartialEq, Eq)] +pub enum CoordinateError { + /// Returned when the dimensions of the coordinates are not compatible. + #[error("dimensions aren't compatible")] + DimensionalityMismatch, + /// Returned when the coordinate is invalid. + #[error("invalid coordinate")] + InvalidCoordinate, + /// Returned when the round trip time is not in a valid range. 
+ #[error("round trip time not in valid range, duration {0:?} is not a value less than 10s")] + InvalidRTT(Duration), +} + +/// Used to record events that occur when updating coordinates. +#[viewit::viewit(setters(prefix = "with"))] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct CoordinateClientStats { + /// Incremented any time we reset our local coordinate because + /// our calculations have resulted in an invalid state. + #[viewit( + getter( + const, + style = "move", + attrs( + doc = "Returns the number of times we reset our local coordinate because our calculations have resulted in an invalid state." + ) + ), + setter(attrs( + doc = "Sets the number of times we reset our local coordinate because our calculations have resulted in an invalid state." + )) + )] + resets: usize, +} + +impl Default for CoordinateClientStats { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl CoordinateClientStats { + #[inline] + const fn new() -> Self { + Self { resets: 0 } + } +} + +struct CoordinateClientInner { + /// The current estimate of the client's network coordinate. + coord: Coordinate, + + /// Origin is a coordinate sitting at the origin. + origin: Coordinate, + + /// Contains the tuning parameters that govern the performance of + /// the algorithm. + opts: CoordinateOptions, + + /// The current index into the adjustmentSamples slice. + adjustment_index: usize, + + /// Used to store samples for the adjustment calculation. + adjustment_samples: SmallVec<[f64; DEFAULT_ADJUSTMENT_WINDOW_SIZE]>, + + /// Used to store the last several RTT samples, + /// keyed by node name. We will use the config's LatencyFilterSamples + /// value to determine how many samples we keep, per node. 
+ latency_filter_samples: HashMap>, +} + +impl CoordinateClientInner +where + I: CheapClone + Eq + core::hash::Hash, +{ + /// Applies a small amount of gravity to pull coordinates towards + /// the center of the coordinate system to combat drift. This assumes that the + /// mutex is locked already. + #[inline] + fn update_gravity(&mut self) { + let dist = self.origin.distance_to(&self.coord).as_secs(); + let force = -1.0 * f64::powf((dist as f64) / self.opts.gravity_rho, 2.0); + self + .coord + .apply_force_in_place(self.opts.height_min, force, &self.origin); + } + + #[inline] + fn latency_filter(&mut self, node: &I, rtt_seconds: f64) -> f64 { + let samples = self + .latency_filter_samples + .entry(node.cheap_clone()) + .or_insert_with(|| SmallVec::with_capacity(self.opts.latency_filter_size)); + + // Add the new sample and trim the list, if needed. + samples.push(rtt_seconds); + if samples.len() > self.opts.latency_filter_size { + samples.remove(0); + } + // Sort a copy of the samples and return the median. + let mut tmp = SmallVec::<[f64; DEFAULT_LATENCY_FILTER_SAMPLES_SIZE]>::from_slice(samples); + tmp.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + tmp[tmp.len() / 2] + } + + /// Updates the Vivaldi portion of the client's coordinate. This + /// assumes that the mutex has been locked already. 
+ fn update_vivaldi(&mut self, other: &Coordinate, mut rtt_seconds: f64) { + const ZERO_THRESHOLD: f64 = 1.0e-6; + + let dist = self.coord.distance_to(other).as_secs_f64(); + rtt_seconds = rtt_seconds.max(ZERO_THRESHOLD); + + let wrongness = ((dist - rtt_seconds) / rtt_seconds).abs(); + + let total_error = (self.coord.error + other.error).max(ZERO_THRESHOLD); + + let weight = self.coord.error / total_error; + self.coord.error = ((self.opts.vivaldi_ce * weight * wrongness) + + (self.coord.error * (1.0 - self.opts.vivaldi_ce * weight))) + .min(self.opts.vivaldi_error_max); + + let force = self.opts.vivaldi_cc * weight * (rtt_seconds - dist); + self + .coord + .apply_force_in_place(self.opts.height_min, force, other); + } + + /// Updates the adjustment portion of the client's coordinate, if + /// the feature is enabled. This assumes that the mutex has been locked already. + fn update_adjustment(&mut self, other: &Coordinate, rtt_seconds: f64) { + if self.opts.adjustment_window_size == 0 { + return; + } + // Note that the existing adjustment factors don't figure in to this + // calculation so we use the raw distance here. + let dist = self.coord.raw_distance_to(other); + self.adjustment_samples[self.adjustment_index] = rtt_seconds - dist; + self.adjustment_index = (self.adjustment_index + 1) % self.opts.adjustment_window_size; + + self.coord.adjustment = + self.adjustment_samples.iter().sum::() / (2.0 * self.opts.adjustment_window_size as f64); + } +} + +/// Manages the estimated network coordinate for a given node, and adjusts +/// it as the node observes round trip times and estimated coordinates from other +/// nodes. The core algorithm is based on Vivaldi, see the documentation for Config +/// for more details. +/// +/// `CoordinateClient` is thread-safe. +// TODO: are there any better ways to avoid using a RwLock? +pub struct CoordinateClient { + inner: RwLock>, + /// Used to record events that occur when updating coordinates. 
+ stats: AtomicUsize, +} + +impl Default for CoordinateClient { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl CoordinateClient { + /// Creates a new client. + #[inline] + pub fn new() -> Self { + Self { + inner: RwLock::new(CoordinateClientInner { + coord: Coordinate::new(), + origin: Coordinate::new(), + opts: CoordinateOptions::new(), + adjustment_index: 0, + adjustment_samples: SmallVec::from_slice(&[0.0; DEFAULT_ADJUSTMENT_WINDOW_SIZE]), + latency_filter_samples: HashMap::new(), + }), + stats: AtomicUsize::new(0), + } + } + + /// Creates a new client with given options. + #[inline] + pub fn with_options(opts: CoordinateOptions) -> Self { + let mut samples = SmallVec::with_capacity(opts.adjustment_window_size); + samples.resize(opts.adjustment_window_size, 0.0); + Self { + inner: RwLock::new(CoordinateClientInner { + coord: Coordinate::with_options(opts.clone()), + origin: Coordinate::with_options(opts.clone()), + opts, + adjustment_index: 0, + adjustment_samples: samples, + latency_filter_samples: HashMap::new(), + }), + stats: AtomicUsize::new(0), + } + } + + /// Returns a copy of the coordinate for this client. + #[inline] + pub fn get_coordinate(&self) -> Coordinate { + self.inner.read().coord.clone() + } + + /// Forces the client's coordinate to a known state. + #[inline] + pub fn set_coordinate(&self, coord: Coordinate) -> Result<(), CoordinateError> { + let mut l = self.inner.write(); + Self::check_coordinate(&l.coord, &coord).map(|_| l.coord = coord) + } + + /// Returns a copy of stats for the client. + #[inline] + pub fn stats(&self) -> CoordinateClientStats { + CoordinateClientStats { + resets: self.stats.load(Ordering::Relaxed), + } + } + + /// Returns the estimated RTT from the client's coordinate to other, the + /// coordinate for another node. 
+ #[inline] + pub fn distance_to(&self, coord: &Coordinate) -> Duration { + self.inner.read().coord.distance_to(coord) + } + + /// Returns an error if the coordinate isn't compatible with + /// this client, or if the coordinate itself isn't valid. This assumes the mutex + /// has been locked already. + #[inline] + fn check_coordinate(this: &Coordinate, coord: &Coordinate) -> Result<(), CoordinateError> { + if !this.is_compatible_with(coord) { + return Err(CoordinateError::DimensionalityMismatch); + } + + if !coord.is_valid() { + return Err(CoordinateError::InvalidCoordinate); + } + + Ok(()) + } +} + +impl CoordinateClient +where + I: CheapClone + Eq + core::hash::Hash, +{ + /// Removes any client state for the given node. + #[inline] + pub fn forget_node(&self, node: &I) { + self.inner.write().latency_filter_samples.remove(node); + } + + /// Update takes other, a coordinate for another node, and rtt, a round trip + /// time observation for a ping to that node, and updates the estimated position of + /// the client's coordinate. Returns the updated coordinate. + pub fn update( + &self, + node: &I, + other: &Coordinate, + rtt: Duration, + ) -> Result { + let mut l = self.inner.write(); + Self::check_coordinate(&l.coord, other)?; + + // The code down below can handle zero RTTs, which we have seen in + // https://github.com/hashicorp/consul/issues/3789, presumably in + // environments with coarse-grained monotonic clocks (we are still + // trying to pin this down). In any event, this is ok from a code PoV + // so we don't need to alert operators with spammy messages. We did + // add a counter so this is still observable, though. 
+ const MAX_RTT: Duration = Duration::from_secs(10); + + if rtt > MAX_RTT { + return Err(CoordinateError::InvalidRTT(rtt)); + } + + #[cfg(feature = "metrics")] + if rtt.is_zero() { + metrics::counter!("serf.coordinate.zero-rtt", l.opts.metric_labels.iter()).increment(1); + } + + let rtt_seconds = l.latency_filter(node, rtt.as_secs_f64()); + l.update_vivaldi(other, rtt_seconds); + l.update_adjustment(other, rtt_seconds); + l.update_gravity(); + + if !l.coord.is_valid() { + self.stats.fetch_add(1, Ordering::Acquire); + l.coord = Coordinate::with_options(l.opts.clone()); + } + + Ok(l.coord.clone()) + } +} + #[cfg(test)] mod tests { use smol_str::SmolStr; diff --git a/serf-core/src/types/message.rs b/serf-core/src/types/message.rs index 9d86d59..1c774bd 100644 --- a/serf-core/src/types/message.rs +++ b/serf-core/src/types/message.rs @@ -86,6 +86,22 @@ pub enum MessageType { } impl MessageType { + /// All message types in order + pub const ALL: &[Self] = &[ + Self::Leave, + Self::Join, + Self::PushPull, + Self::UserEvent, + Self::Query, + Self::QueryResponse, + Self::ConflictResponse, + Self::Relay, + #[cfg(feature = "encryption")] + Self::KeyRequest, + #[cfg(feature = "encryption")] + Self::KeyResponse, + ]; + /// Get the string representation of the message type #[inline] pub fn as_str(&self) -> std::borrow::Cow<'static, str> { diff --git a/serf-core/src/types/quickcheck_impl.rs b/serf-core/src/types/quickcheck_impl.rs new file mode 100644 index 0000000..2cce92d --- /dev/null +++ b/serf-core/src/types/quickcheck_impl.rs @@ -0,0 +1,311 @@ +use core::hash::Hash; + +use quickcheck::{Arbitrary, Gen}; +use smol_str::SmolStr; + +use super::{ + ConflictResponseMessage, DelegateVersion, Filter, JoinMessage, LamportTime, LeaveMessage, Member, + MemberStatus, MessageType, ProtocolVersion, PushPullMessage, QueryFlag, QueryMessage, + QueryResponseMessage, TagFilter, Tags, UserEvent, UserEventMessage, UserEvents, coordinate::Coordinate, +}; + +#[cfg(feature = "encryption")] +use 
super::{KeyRequestMessage, KeyResponseMessage}; + +impl Arbitrary for ProtocolVersion { + fn arbitrary(g: &mut Gen) -> Self { + ProtocolVersion::from(u8::arbitrary(g)) + } +} + +impl Arbitrary for DelegateVersion { + fn arbitrary(g: &mut Gen) -> Self { + DelegateVersion::from(u8::arbitrary(g)) + } +} + +impl Arbitrary for UserEvent { + fn arbitrary(g: &mut Gen) -> Self { + Self { + name: String::arbitrary(g).into(), + payload: Vec::arbitrary(g).into(), + } + } +} + +impl Arbitrary for UserEvents { + fn arbitrary(g: &mut Gen) -> Self { + Self { + ltime: u64::arbitrary(g).into(), + events: Vec::arbitrary(g).into(), + } + } +} + +impl Arbitrary for UserEventMessage { + fn arbitrary(g: &mut Gen) -> Self { + Self { + ltime: u64::arbitrary(g).into(), + name: String::arbitrary(g).into(), + payload: Vec::arbitrary(g).into(), + cc: bool::arbitrary(g), + } + } +} + +impl Arbitrary for LamportTime { + fn arbitrary(g: &mut Gen) -> Self { + LamportTime::from(u64::arbitrary(g)) + } +} + +impl Arbitrary for TagFilter { + fn arbitrary(g: &mut Gen) -> Self { + Self::new() + .with_tag(String::arbitrary(g).into()) + .maybe_expr(if Arbitrary::arbitrary(g) { + let complexity = *g.choose(&[1, 2, 3, 4, 5]).unwrap(); + let mut patterns = Vec::new(); + + // Basic character classes and quantifiers + let character_classes = vec![ + r"\d", + r"\w", + r"\s", + r"[a-z]", + r"[A-Z]", + r"[0-9]", + r"[a-zA-Z]", + r"[a-zA-Z0-9]", + r".", + ]; + + let quantifiers = vec!["", "*", "+", "?", "{1,3}", "{2,5}"]; + + // Add more complex patterns for higher complexity + let mut extended_classes = character_classes.clone(); + if complexity > 1 { + extended_classes.extend(vec![r"[^a-z]", r"[^0-9]", r"\D", r"\W", r"\S"]); + } + + if complexity > 2 { + // Add a group with random content + let char_class = *g.choose(&extended_classes).unwrap(); + let quantifier = *g.choose(&quantifiers).unwrap(); + patterns.push(format!("({}{})", char_class, quantifier)); + } + + // Generate random pattern parts + for _ in 
0..complexity { + let char_class = *g.choose(&extended_classes).unwrap(); + let quantifier = *g.choose(&quantifiers).unwrap(); + patterns.push(format!("{}{}", char_class, quantifier)); + } + + // Maybe add anchors for higher complexity + if complexity > 2 && rand::random_ratio(7, 10) { + if Arbitrary::arbitrary(g) { + patterns.insert(0, "^".to_string()); + } + if Arbitrary::arbitrary(g) { + patterns.push("$".to_string()); + } + } + + // Add alternation for even higher complexity + if complexity > 3 && rand::random_ratio(6, 10) { + let char_class = *g.choose(&extended_classes).unwrap(); + let quantifier = *g.choose(&quantifiers).unwrap(); + patterns.push(format!("|{}{}", char_class, quantifier)); + } + + Some(patterns.join("").try_into().unwrap()) + } else { + None + }) + } +} + +impl Arbitrary for Tags { + fn arbitrary(g: &mut Gen) -> Self { + Self::from_iter( + Vec::<(String, String)>::arbitrary(g) + .into_iter() + .map(|(k, v)| (SmolStr::from(k), SmolStr::from(v))), + ) + } +} + +impl Arbitrary for QueryMessage +where + I: Arbitrary, + A: Arbitrary, +{ + fn arbitrary(g: &mut Gen) -> Self { + Self { + ltime: Arbitrary::arbitrary(g), + flags: Arbitrary::arbitrary(g), + id: Arbitrary::arbitrary(g), + from: Arbitrary::arbitrary(g), + filters: Vec::arbitrary(g).into(), + relay_factor: Arbitrary::arbitrary(g), + timeout: Arbitrary::arbitrary(g), + name: String::arbitrary(g).into(), + payload: Vec::arbitrary(g).into(), + } + } +} + +impl Arbitrary for QueryFlag { + fn arbitrary(g: &mut Gen) -> Self { + if bool::arbitrary(g) { + QueryFlag::NO_BROADCAST + } else { + QueryFlag::ACK + } + } +} + +impl Arbitrary for QueryResponseMessage +where + I: Arbitrary, + A: Arbitrary, +{ + fn arbitrary(g: &mut Gen) -> Self { + Self { + ltime: Arbitrary::arbitrary(g), + id: Arbitrary::arbitrary(g), + from: Arbitrary::arbitrary(g), + flags: Arbitrary::arbitrary(g), + payload: Vec::arbitrary(g).into(), + } + } +} + +impl Arbitrary for Filter +where + I: Arbitrary, +{ + fn arbitrary(g: 
&mut Gen) -> Self { + if bool::arbitrary(g) { + Filter::Id(Vec::::arbitrary(g).into()) + } else { + Filter::Tag(TagFilter::arbitrary(g)) + } + } +} + +impl Arbitrary for PushPullMessage +where + I: Arbitrary + Hash + Eq, +{ + fn arbitrary(g: &mut Gen) -> Self { + Self { + ltime: Arbitrary::arbitrary(g), + status_ltimes: Vec::<(I, LamportTime)>::arbitrary(g).into_iter().collect(), + left_members: Vec::::arbitrary(g).into_iter().collect(), + event_ltime: Arbitrary::arbitrary(g), + events: Vec::::arbitrary(g).into(), + query_ltime: Arbitrary::arbitrary(g), + } + } +} + +impl Arbitrary for MemberStatus { + fn arbitrary(g: &mut Gen) -> Self { + MemberStatus::from(u8::arbitrary(g)) + } +} + +impl Arbitrary for Member +where + I: Arbitrary, + A: Arbitrary, +{ + fn arbitrary(g: &mut Gen) -> Self { + Self { + node: Arbitrary::arbitrary(g), + tags: Tags::arbitrary(g).into(), + status: Arbitrary::arbitrary(g), + memberlist_protocol_version: Arbitrary::arbitrary(g), + memberlist_delegate_version: Arbitrary::arbitrary(g), + protocol_version: Arbitrary::arbitrary(g), + delegate_version: Arbitrary::arbitrary(g), + } + } +} + +impl Arbitrary for LeaveMessage +where + I: Arbitrary, +{ + fn arbitrary(g: &mut Gen) -> Self { + Self { + ltime: Arbitrary::arbitrary(g), + prune: bool::arbitrary(g), + id: Arbitrary::arbitrary(g), + } + } +} + +#[cfg(feature = "encryption")] +impl Arbitrary for KeyRequestMessage { + fn arbitrary(g: &mut Gen) -> Self { + Self { + key: Arbitrary::arbitrary(g), + } + } +} + +#[cfg(feature = "encryption")] +impl Arbitrary for KeyResponseMessage { + fn arbitrary(g: &mut Gen) -> Self { + Self { + result: Arbitrary::arbitrary(g), + message: String::arbitrary(g).into(), + keys: Vec::arbitrary(g).into(), + primary_key: Arbitrary::arbitrary(g), + } + } +} + +impl Arbitrary for JoinMessage +where + I: Arbitrary, +{ + fn arbitrary(g: &mut Gen) -> Self { + Self { + ltime: Arbitrary::arbitrary(g), + id: Arbitrary::arbitrary(g), + } + } +} + +impl Arbitrary for 
ConflictResponseMessage +where + I: Arbitrary, + A: Arbitrary, +{ + fn arbitrary(g: &mut Gen) -> Self { + Self { + member: Arbitrary::arbitrary(g), + } + } +} + +impl Arbitrary for MessageType { + fn arbitrary(g: &mut Gen) -> Self { + MessageType::from(u8::arbitrary(g) % Self::ALL.len() as u8) + } +} + +impl Arbitrary for Coordinate { + fn arbitrary(g: &mut Gen) -> Self { + Self { + portion: Vec::arbitrary(g).into(), + error: Arbitrary::arbitrary(g), + adjustment: Arbitrary::arbitrary(g), + height: Arbitrary::arbitrary(g), + } + } +} diff --git a/serf-core/src/types/tags.rs b/serf-core/src/types/tags.rs index 1725959..d99970f 100644 --- a/serf-core/src/types/tags.rs +++ b/serf-core/src/types/tags.rs @@ -18,6 +18,9 @@ const TAGS_BYTE: u8 = merge(WireType::LengthDelimited, TAGS_TAG); derive_more::Into, derive_more::Deref, derive_more::DerefMut, + derive_more::AsRef, + derive_more::AsMut, + derive_more::IntoIterator, )] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde", serde(transparent))] @@ -27,27 +30,16 @@ pub struct Tags( IndexMap, ); -impl IntoIterator for Tags { - type Item = (SmolStr, SmolStr); - type IntoIter = indexmap::map::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() - } -} - -impl FromIterator<(SmolStr, SmolStr)> for Tags { - fn from_iter>(iter: T) -> Self { - Self(iter.into_iter().collect()) - } -} - -impl<'a> FromIterator<(&'a str, &'a str)> for Tags { - fn from_iter>(iter: T) -> Self { +impl FromIterator<(K, V)> for Tags +where + K: Into, + V: Into, +{ + fn from_iter>(iter: T) -> Self { Self( iter .into_iter() - .map(|(k, v)| (SmolStr::new(k), SmolStr::new(v))) + .map(|(k, v)| (k.into(), v.into())) .collect(), ) } diff --git a/serf-core/src/types/tests.rs b/serf-core/src/types/tests.rs new file mode 100644 index 0000000..4dfdc26 --- /dev/null +++ b/serf-core/src/types/tests.rs @@ -0,0 +1,64 @@ +use std::hash::Hash; + +use quickcheck::{Arbitrary, Gen}; + +use super::*; 
+ +#[derive(Clone, Debug)] +enum Message { + /// Leave message + Leave(LeaveMessage), + /// Join message + Join(JoinMessage), + /// PushPull message + PushPull(PushPullMessage), + /// UserEvent message + UserEvent(UserEventMessage), + /// Query message + Query(QueryMessage), + /// QueryResponse message + QueryResponse(QueryResponseMessage), + /// ConflictResponse message + ConflictResponse(ConflictResponseMessage), + #[cfg(feature = "encryption")] + /// KeyRequest message + KeyRequest(KeyRequestMessage), + #[cfg(feature = "encryption")] + /// KeyResponse message + KeyResponse(KeyResponseMessage), +} + + +impl Arbitrary for Message +where + I: Arbitrary + Hash + Eq, + A: Arbitrary, +{ + fn arbitrary(g: &mut Gen) -> Self { + loop { + let variant = MessageType::arbitrary(g); + + return match variant { + MessageType::ConflictResponse => Message::ConflictResponse(ConflictResponseMessage::arbitrary(g)), + MessageType::Join => Message::Join(JoinMessage::arbitrary(g)), + MessageType::Leave => Message::Leave(LeaveMessage::arbitrary(g)), + MessageType::PushPull => Message::PushPull(PushPullMessage::arbitrary(g)), + MessageType::Query => Message::Query(QueryMessage::arbitrary(g)), + MessageType::QueryResponse => Message::QueryResponse(QueryResponseMessage::arbitrary(g)), + MessageType::UserEvent => Message::UserEvent(UserEventMessage::arbitrary(g)), + #[cfg(feature = "encryption")] + MessageType::KeyRequest => Message::KeyRequest(KeyRequestMessage::arbitrary(g)), + #[cfg(feature = "encryption")] + MessageType::KeyResponse => Message::KeyResponse(KeyResponseMessage::arbitrary(g)), + _ => continue, + }; + } + } +} + + +#[quickcheck_macros::quickcheck] +fn message_arbitrary(msg: Message, node: Option>) -> bool { + todo!() +} + From f8512ff7127d9b1b1da629ae9c2f0a36f0862eb3 Mon Sep 17 00:00:00 2001 From: al8n Date: Sun, 2 Mar 2025 00:50:04 +0800 Subject: [PATCH 12/39] WIP --- serf-core/Cargo.toml | 1 + serf-core/src/event.rs | 2 +- serf-core/src/key_manager.rs | 2 +- 
serf-core/src/serf/api.rs | 6 +- serf-core/src/serf/base.rs | 10 +- serf-core/src/serf/base/tests.rs | 4 +- .../src/serf/base/tests/serf/delegate.rs | 2 +- serf-core/src/serf/delegate.rs | 2 +- serf-core/src/serf/internal_query.rs | 10 +- serf-core/src/serf/query.rs | 4 +- serf-core/src/types.rs | 9 +- serf-core/src/types/conflict.rs | 2 +- serf-core/src/types/filter.rs | 29 +- serf-core/src/types/filter/tag_filter.rs | 8 +- serf-core/src/types/join.rs | 6 +- serf-core/src/types/key.rs | 2 +- serf-core/src/types/leave.rs | 2 +- serf-core/src/types/member.rs | 4 +- serf-core/src/types/message.rs | 288 +++++++++-------- serf-core/src/types/push_pull.rs | 4 +- serf-core/src/types/query.rs | 10 +- serf-core/src/types/query/response.rs | 9 +- serf-core/src/types/quickcheck_impl.rs | 3 +- serf-core/src/types/tags.rs | 2 + serf-core/src/types/tests.rs | 299 +++++++++++++++++- serf-core/src/types/user_event/message.rs | 2 +- serf-core/src/types/user_event/user_events.rs | 11 +- 27 files changed, 533 insertions(+), 200 deletions(-) diff --git a/serf-core/Cargo.toml b/serf-core/Cargo.toml index abecc7c..596b0b3 100644 --- a/serf-core/Cargo.toml +++ b/serf-core/Cargo.toml @@ -84,6 +84,7 @@ tempfile = "3" memberlist-core = { workspace = true, features = ["quickcheck", "arbitrary"] } quickcheck_macros = "1" quickcheck.workspace = true +paste = "1" [package.metadata.docs.rs] all-features = true diff --git a/serf-core/src/event.rs b/serf-core/src/event.rs index 78fc071..fa1ae2d 100644 --- a/serf-core/src/event.rs +++ b/serf-core/src/event.rs @@ -91,7 +91,7 @@ where flags: QueryFlag::empty(), payload: msg, }; - let buf = crate::types::Encodable::encode_to_bytes(&resp)?; + let buf = crate::types::encode_message_to_bytes(&resp)?; self .respond_with_message_and_response(respond_to, relay_factor, buf, resp) .await diff --git a/serf-core/src/key_manager.rs b/serf-core/src/key_manager.rs index 7bd4689..51683f6 100644 --- a/serf-core/src/key_manager.rs +++ b/serf-core/src/key_manager.rs 
@@ -181,7 +181,7 @@ where event: InternalQueryEvent, ) -> Result, Error> { let kr = KeyRequestMessage { key }; - let buf = crate::types::Encodable::encode_to_bytes(&kr)?; + let buf = crate::types::encode_message_to_bytes(&kr)?; let serf = self.serf.get().unwrap(); let mut q_param = serf.default_query_param().await; diff --git a/serf-core/src/serf/api.rs b/serf-core/src/serf/api.rs index f8b4924..50487cd 100644 --- a/serf-core/src/serf/api.rs +++ b/serf-core/src/serf/api.rs @@ -268,7 +268,7 @@ where }; // Start broadcasting the event - let len = crate::types::Encodable::encoded_len(&msg); + let len = crate::types::encoded_message_len(&msg); // Check the size after encoding to be sure again that // we're not attempting to send over the specified size limit. @@ -280,7 +280,7 @@ where return Err(Error::raw_user_event_too_large(len)); } - let raw = crate::types::Encodable::encode_to_bytes(&msg)?; + let raw = crate::types::encode_message_to_bytes(&msg)?; self.inner.event_clock.increment(); @@ -459,7 +459,7 @@ where // other node alive. if self.has_alive_members().await { let (notify_tx, notify_rx) = async_channel::bounded(1); - let msg = crate::types::Encodable::encode_to_bytes(&msg)?; + let msg = crate::types::encode_message_to_bytes(&msg)?; self.broadcast(msg, Some(notify_tx)).await?; futures::select! 
{ diff --git a/serf-core/src/serf/base.rs b/serf-core/src/serf/base.rs index e38eb5f..fc910e7 100644 --- a/serf-core/src/serf/base.rs +++ b/serf-core/src/serf/base.rs @@ -386,7 +386,7 @@ where // Process update locally self.handle_node_join_intent(&msg).await; - let msg = crate::types::Encodable::encode_to_bytes(&msg)?; + let msg = crate::types::encode_message_to_bytes(&msg)?; // Start broadcasting the update if let Err(e) = self.broadcast(msg, None).await { tracing::warn!(err=%e, "serf: failed to broadcast join intent"); @@ -471,7 +471,7 @@ where return Ok(()); } - let msg = crate::types::Encodable::encode_to_bytes(&msg)?; + let msg = crate::types::encode_message_to_bytes(&msg)?; // Broadcast the remove let (ntx, nrx) = async_channel::bounded(1); self.broadcast(msg, Some(ntx)).await?; @@ -917,14 +917,14 @@ where }; // Encode the query - let len = crate::types::Encodable::encoded_len(&q); + let len = crate::types::encoded_message_len(&q); // Check the size if len > self.inner.opts.query_size_limit { return Err(Error::query_too_large(len)); } - let raw = crate::types::Encodable::encode_to_bytes(&q)?; + let raw = crate::types::encode_message_to_bytes(&q)?; // Register QueryResponse to track acks and responses let resp = QueryResponse::from_query(&q, self.inner.memberlist.num_online_members().await); @@ -1086,7 +1086,7 @@ where payload: Bytes::new(), }; - match crate::types::Encodable::encode_to_bytes(&ack) { + match crate::types::encode_message_to_bytes(&ack) { Ok(raw) => { let (name, payload, from, relay_factor) = match q { Either::Left(q) => ( diff --git a/serf-core/src/serf/base/tests.rs b/serf-core/src/serf/base/tests.rs index 5d7c71c..658917a 100644 --- a/serf-core/src/serf/base/tests.rs +++ b/serf-core/src/serf/base/tests.rs @@ -354,10 +354,10 @@ pub async fn estimate_max_keys_in_list_key_response_factor( let mut found = 0; for i in (0..=resp.keys.len()).rev() { - let dst = crate::types::Encodable::encode_to_bytes(&resp).unwrap(); + let dst = 
crate::types::encode_message_to_bytes(&resp).unwrap(); let qresp = query.create_response(dst); - let dst = crate::types::Encodable::encode_to_bytes(&qresp).unwrap(); + let dst = crate::types::encode_message_to_bytes(&qresp).unwrap(); if query.check_response_size(dst.len()).is_err() { resp.keys.truncate(i); continue; diff --git a/serf-core/src/serf/base/tests/serf/delegate.rs b/serf-core/src/serf/base/tests/serf/delegate.rs index cdf779a..ab9e0a5 100644 --- a/serf-core/src/serf/base/tests/serf/delegate.rs +++ b/serf-core/src/serf/base/tests/serf/delegate.rs @@ -148,7 +148,7 @@ where query_ltime: 100.into(), }; - let buf = crate::types::Encodable::encode_to_bytes(&pp).unwrap(); + let buf = crate::types::encode_message_to_bytes(&pp).unwrap(); // Merge in fake state d.merge_remote_state(&buf, false).await; diff --git a/serf-core/src/serf/delegate.rs b/serf-core/src/serf/delegate.rs index 7246cd7..f750da1 100644 --- a/serf-core/src/serf/delegate.rs +++ b/serf-core/src/serf/delegate.rs @@ -404,7 +404,7 @@ where }; drop(members); - match crate::types::Encodable::encode_to_bytes(&pp) { + match crate::types::encode_message_to_bytes(&pp) { Ok(buf) => buf, Err(e) => { tracing::error!(err=%e, "serf: failed to encode local state"); diff --git a/serf-core/src/serf/internal_query.rs b/serf-core/src/serf/internal_query.rs index 9e2a23c..b8ec23f 100644 --- a/serf-core/src/serf/internal_query.rs +++ b/serf-core/src/serf/internal_query.rs @@ -149,7 +149,7 @@ where match out { Some(state) => { let resp = crate::types::ConflictResponseMessageBorrow::from(state.member()); - match crate::types::Encodable::encode_to_bytes(&resp) { + match crate::types::encode_message_to_bytes(&resp) { Ok(raw) => { if let Err(e) = ev.respond(raw).await { tracing::error!(target="serf", err=%e, "failed to respond to conflict query"); @@ -426,12 +426,12 @@ where (q.ctx.this.inner.opts.query_response_size_limit / MIN_ENCODED_KEY_LENGTH).min(actual); for i in (0..=max_list_keys).rev() { - let kraw = 
crate::types::Encodable::encode_to_bytes(&*resp)?; + let kraw = crate::types::encode_message_to_bytes(&*resp)?; // create response let qresp = q.create_response(kraw.clone()); - let encoded_len = crate::types::Encodable::encoded_len(&qresp); + let encoded_len = crate::types::encoded_message_len(&qresp); // Check the size limit if q.check_response_size(encoded_len).is_err() { resp.keys.drain(i..); @@ -443,7 +443,7 @@ where } // encode response - let qraw = crate::types::Encodable::encode_to_bytes(&qresp)?; + let qraw = crate::types::encode_message_to_bytes(&qresp)?; if actual > i { tracing::warn!("serf: {}", resp.message); @@ -469,7 +469,7 @@ where tracing::error!(target="serf", err=%e, "failed to respond to key query"); } } - _ => match crate::types::Encodable::encode_to_bytes(&*resp) { + _ => match crate::types::encode_message_to_bytes(&*resp) { Ok(raw) => { if let Err(e) = q.respond(raw).await { tracing::error!(target="serf", err=%e, "failed to respond to key query"); diff --git a/serf-core/src/serf/query.rs b/serf-core/src/serf/query.rs index 28802e2..a50ba02 100644 --- a/serf-core/src/serf/query.rs +++ b/serf-core/src/serf/query.rs @@ -556,14 +556,14 @@ where } // Prep the relay message, which is a wrapped version of the original. - let encoded_len = crate::types::Encodable::encoded_len_with_relay(&resp, &node); + let encoded_len = crate::types::encoded_relay_message_len(&resp, &node); if encoded_len > self.inner.opts.query_response_size_limit { return Err(Error::relayed_response_too_large( self.inner.opts.query_response_size_limit, )); } - let raw = crate::types::Encodable::encode_relay_to_bytes(&resp, &node)?; + let raw = crate::types::encode_relay_message_to_bytes(&resp, &node)?; // Relay to a random set of peers. 
let relay_members = random_members(relay_factor as usize, members); diff --git a/serf-core/src/types.rs b/serf-core/src/types.rs index 5bb0720..6f0be3e 100644 --- a/serf-core/src/types.rs +++ b/serf-core/src/types.rs @@ -14,7 +14,6 @@ mod quickcheck_impl; #[cfg(test)] mod tests; - mod clock; pub use clock::*; @@ -62,10 +61,12 @@ pub use key::*; #[cfg(debug_assertions)] #[inline] -fn debug_assert_write_eq(actual: usize, expected: usize) { +fn debug_assert_write_eq(actual: usize, expected: usize) { debug_assert_eq!( - actual, expected, - "expect writting {expected} bytes, but actual write {actual} bytes" + actual, + expected, + "{}: expect writting {expected} bytes, but actual write {actual} bytes", + core::any::type_name::(), ); } diff --git a/serf-core/src/types/conflict.rs b/serf-core/src/types/conflict.rs index 03fb0b8..5bb686c 100644 --- a/serf-core/src/types/conflict.rs +++ b/serf-core/src/types/conflict.rs @@ -77,7 +77,7 @@ where .map_err(|e| e.update(self.encoded_len_in(), buf.len()))?; #[cfg(debug_assertions)] - super::debug_assert_write_eq(offset, self.encoded_len_in()); + super::debug_assert_write_eq::(offset, self.encoded_len_in()); Ok(offset) } diff --git a/serf-core/src/types/filter.rs b/serf-core/src/types/filter.rs index 52f38dd..376992c 100644 --- a/serf-core/src/types/filter.rs +++ b/serf-core/src/types/filter.rs @@ -99,10 +99,6 @@ where Self: Sized, { let buf_len = buf.len(); - if buf_len < 1 { - return Err(DecodeError::buffer_underflow()); - } - let mut offset = 0; let mut ids_offsets = None; let mut num_ids = 0; @@ -111,7 +107,8 @@ where while offset < buf_len { match buf[offset] { val if val == Filter::::id_byte() => { - let readed = skip(WireType::LengthDelimited, &buf[offset..])?; + offset += 1; + let readed = skip(I::WIRE_TYPE, &buf[offset..])?; if let Some((ref mut fnso, ref mut lnso)) = ids_offsets { if *fnso > offset { *fnso = offset - 1; @@ -160,12 +157,15 @@ where tag } else if let Some((start, end)) = ids_offsets { Self::Id( - 
RepeatedDecoder::new(FILTER_ID_TAG, WireType::LengthDelimited, buf) + RepeatedDecoder::new(FILTER_ID_TAG, I::WIRE_TYPE, buf) .with_nums(num_ids) .with_offsets(start, end), ) } else { - return Err(DecodeError::missing_field("Filter", "value")); + Self::Id( + RepeatedDecoder::new(FILTER_ID_TAG, I::WIRE_TYPE, buf) + .with_nums(0), + ) }, )) } @@ -205,14 +205,13 @@ where } fn encoded_len(&self) -> usize { - 1usize - + match self { - Filter::Id(ids) => ids - .iter() - .map(|id| 1 + id.encoded_len_with_length_delimited()) - .sum::(), - Filter::Tag(tag) => 1 + tag.encoded_len_with_length_delimited(), - } + match self { + Filter::Id(ids) => ids + .iter() + .map(|id| 1 + id.encoded_len_with_length_delimited()) + .sum::(), + Filter::Tag(tag) => 1 + tag.encoded_len_with_length_delimited(), + } } fn encode(&self, buf: &mut [u8]) -> Result { diff --git a/serf-core/src/types/filter/tag_filter.rs b/serf-core/src/types/filter/tag_filter.rs index e566800..c2ecc72 100644 --- a/serf-core/src/types/filter/tag_filter.rs +++ b/serf-core/src/types/filter/tag_filter.rs @@ -38,7 +38,8 @@ impl<'a> DataRef<'a, TagFilter> for TagFilterRef<'a> { } offset += 1; - let (read, value) = <&str as DataRef<'_, SmolStr>>::decode(&src[offset..])?; + let (read, value) = + <&str as DataRef<'_, SmolStr>>::decode_length_delimited(&src[offset..])?; offset += read; tag = Some(value); } @@ -48,7 +49,8 @@ impl<'a> DataRef<'a, TagFilter> for TagFilterRef<'a> { } offset += 1; - let (read, value) = <&str as DataRef<'_, SmolStr>>::decode(&src[offset..])?; + let (read, value) = + <&str as DataRef<'_, SmolStr>>::decode_length_delimited(&src[offset..])?; offset += read; expr = Some(value); } @@ -191,7 +193,7 @@ impl Data for TagFilter { } #[cfg(debug_assertions)] - super::super::debug_assert_write_eq(offset, self.encoded_len()); + super::super::debug_assert_write_eq::(offset, self.encoded_len()); Ok(offset) } diff --git a/serf-core/src/types/join.rs b/serf-core/src/types/join.rs index 5c483da..d00fd54 100644 --- 
a/serf-core/src/types/join.rs +++ b/serf-core/src/types/join.rs @@ -75,7 +75,7 @@ where offset += read; ltime = Some(value); } - ID_TAG => { + b if b == JoinMessage::::id_byte() => { if id.is_some() { return Err(DecodeError::duplicate_field("JoinMessage", "id", ID_TAG)); } @@ -139,7 +139,7 @@ where buf[offset] = LTIME_BYTE; offset += 1; - offset += self.ltime.encode(buf)?; + offset += self.ltime.encode(&mut buf[offset..])?; if buf_len <= offset { return Err(EncodeError::insufficient_buffer( @@ -154,7 +154,7 @@ where offset += self.id.encode_length_delimited(&mut buf[offset..])?; #[cfg(debug_assertions)] - super::debug_assert_write_eq(offset, self.encoded_len()); + super::debug_assert_write_eq::(offset, self.encoded_len()); Ok(offset) } diff --git a/serf-core/src/types/key.rs b/serf-core/src/types/key.rs index be1705a..4247e51 100644 --- a/serf-core/src/types/key.rs +++ b/serf-core/src/types/key.rs @@ -373,7 +373,7 @@ impl Data for KeyResponseMessage { } #[cfg(debug_assertions)] - super::debug_assert_write_eq(offset, self.encoded_len()); + super::debug_assert_write_eq::(offset, self.encoded_len()); Ok(offset) } diff --git a/serf-core/src/types/leave.rs b/serf-core/src/types/leave.rs index 42e3d9e..2fdeee0 100644 --- a/serf-core/src/types/leave.rs +++ b/serf-core/src/types/leave.rs @@ -195,7 +195,7 @@ where .map_err(|e| e.update(self.encoded_len(), buf_len))?; #[cfg(debug_assertions)] - super::debug_assert_write_eq(offset, self.encoded_len()); + super::debug_assert_write_eq::(offset, self.encoded_len()); Ok(offset) } diff --git a/serf-core/src/types/member.rs b/serf-core/src/types/member.rs index a8c174b..b090ae8 100644 --- a/serf-core/src/types/member.rs +++ b/serf-core/src/types/member.rs @@ -492,9 +492,11 @@ where bail!(self(offset, buf_len)); buf[offset] = DELEGATE_VERSION_BYTE; offset += 1; + buf[offset] = self.delegate_version.into(); + offset += 1; #[cfg(debug_assertions)] - super::debug_assert_write_eq(offset, self.encoded_len()); + 
super::debug_assert_write_eq::(offset, self.encoded_len()); Ok(offset) } diff --git a/serf-core/src/types/message.rs b/serf-core/src/types/message.rs index 1c774bd..5d10d12 100644 --- a/serf-core/src/types/message.rs +++ b/serf-core/src/types/message.rs @@ -167,7 +167,15 @@ macro_rules! bail { ($this:ident($offset:expr, $len:ident)) => { if $offset >= $len { return Err(EncodeError::insufficient_buffer( - Encodable::encoded_len($this), + encoded_message_len($this), + $len, + )); + } + }; + (@relay $this:ident($offset:expr, $len:ident, $node:ident)) => { + if $offset >= $len { + return Err(EncodeError::insufficient_buffer( + encoded_relay_message_len($this, $node), $len, )); } @@ -182,94 +190,18 @@ const RELAY_MSG_BYTE: u8 = merge(WireType::LengthDelimited, RELAY_MSG_TAG); /// A trait for encoding messages. pub trait Encodable { + const ID: u8; + /// Encodes the message into a buffer. fn encode(&self, buf: &mut [u8]) -> Result; - /// Encodes a relay message into a buffer. - fn encode_relay(&self, node: &Node, buf: &mut [u8]) -> Result - where - I: Data, - A: Data, - { - let mut offset = 0; - let buf_len = buf.len(); - - if offset >= buf_len { - return Err(EncodeError::insufficient_buffer( - self.encoded_len_with_relay(node), - buf_len, - )); - } - - buf[offset] = RELAY_MESSAGE_BYTE; - offset += 1; - - if offset >= buf_len { - return Err(EncodeError::insufficient_buffer( - self.encoded_len_with_relay(node), - buf_len, - )); - } - - buf[offset] = RELAY_NODE_BYTE; - offset += 1; - - offset += node - .encode_length_delimited(&mut buf[offset..]) - .map_err(|e| e.update(self.encoded_len_with_relay(node), buf_len))?; - - if offset >= buf_len { - return Err(EncodeError::insufficient_buffer( - self.encoded_len_with_relay(node), - buf_len, - )); - } - - buf[offset] = RELAY_MSG_BYTE; - offset += 1; - - offset += self - .encode(&mut buf[offset..]) - .map_err(|e| e.update(self.encoded_len_with_relay(node), buf_len))?; - - #[cfg(debug_assertions)] - 
super::debug_assert_write_eq(offset, self.encoded_len_with_relay(node)); - - Ok(offset) - } - - /// Encodes the message into a [`Bytes`]. - fn encode_to_bytes(&self) -> Result { - let len = self.encoded_len(); - let mut buf = vec![0; len]; - self.encode(&mut buf).map(|_| Bytes::from(buf)) - } - - /// Encodes a relay message into a [`Bytes`]. - fn encode_relay_to_bytes(&self, node: &Node) -> Result - where - I: Data, - A: Data, - { - let len = self.encoded_len_with_relay(node); - let mut buf = vec![0; len]; - self.encode_relay(node, &mut buf).map(|_| Bytes::from(buf)) - } - /// Returns the encoded length of the message. fn encoded_len(&self) -> usize; - - /// Returns the encoded length of the message with a relay tag. - fn encoded_len_with_relay(&self, node: &Node) -> usize - where - I: Data, - A: Data, - { - 1 + node.encoded_len_with_length_delimited() + 1 + self.encoded_len() - } } impl Encodable for &T { + const ID: u8 = T::ID; + fn encode(&self, buf: &mut [u8]) -> Result { (*self).encode(buf) } @@ -294,24 +226,14 @@ macro_rules! impl_encodable { $($generic: Data,)+ )? { - fn encode(&self, buf: &mut [u8]) -> Result { - let mut offset = 0; - let buf_len = buf.len(); - bail!(self(offset, buf_len)); - - buf[offset] = $id; - offset += 1; - - offset += self.encode_length_delimited(&mut buf[offset..])?; - - #[cfg(debug_assertions)] - super::debug_assert_write_eq(offset, Encodable::encoded_len(self)); + const ID: u8 = $id; - Ok(offset) + fn encode(&self, buf: &mut [u8]) -> Result { + Data::encode(self, buf) } fn encoded_len(&self) -> usize { - 1 + self.encoded_len_with_length_delimited() + Data::encoded_len(self) } } )* @@ -321,7 +243,6 @@ macro_rules! 
impl_encodable { impl_encodable!( LeaveMessage = LEAVE_MESSAGE_BYTE, JoinMessage = JOIN_MESSAGE_BYTE, - // PushPullMessage = PUSH_PULL_MESSAGE_BYTE, UserEventMessage = USER_EVENT_MESSAGE_BYTE, QueryMessage = QUERY_MESSAGE_BYTE, QueryResponseMessage = QUERY_RESPONSE_MESSAGE_BYTE, @@ -337,24 +258,29 @@ where I: Data, A: Data, { - fn encode(&self, buf: &mut [u8]) -> Result { - let mut offset = 0; - let buf_len = buf.len(); - bail!(self(offset, buf_len)); + const ID: u8 = CONFLICT_RESPONSE_MESSAGE_BYTE; - buf[offset] = CONFLICT_RESPONSE_MESSAGE_BYTE; - offset += 1; + fn encode(&self, buf: &mut [u8]) -> Result { + self.encode_in(buf) + } - offset += self.encode_in(&mut buf[offset..])?; + fn encoded_len(&self) -> usize { + self.encoded_len_in() + } +} - #[cfg(debug_assertions)] - super::debug_assert_write_eq(offset, Encodable::encoded_len(self)); +impl super::Encodable for PushPullMessage +where + I: Data + Eq + core::hash::Hash, +{ + const ID: u8 = PUSH_PULL_MESSAGE_BYTE; - Ok(offset) + fn encode(&self, buf: &mut [u8]) -> Result { + Data::encode(self, buf) } fn encoded_len(&self) -> usize { - 1 + self.encoded_len_in() + Data::encoded_len(self) } } @@ -362,24 +288,14 @@ impl super::Encodable for PushPullMessageBorrow<'_, I> where I: Data, { - fn encode(&self, buf: &mut [u8]) -> Result { - let mut offset = 0; - let buf_len = buf.len(); - bail!(self(offset, buf_len)); - - buf[offset] = PUSH_PULL_MESSAGE_BYTE; - offset += 1; - - offset += self.encode_in(&mut buf[offset..])?; + const ID: u8 = PUSH_PULL_MESSAGE_BYTE; - #[cfg(debug_assertions)] - super::debug_assert_write_eq(offset, Encodable::encoded_len(self)); - - Ok(offset) + fn encode(&self, buf: &mut [u8]) -> Result { + self.encode_in(buf) } fn encoded_len(&self) -> usize { - 1 + self.encoded_len_in() + self.encoded_len_in() } } @@ -437,6 +353,127 @@ impl MessageRef<'_, I, A> { } } +/// Encode a message into a Bytes. 
+pub fn encode_message_to_bytes(msg: &T) -> Result +where + T: Encodable, +{ + let len = encoded_message_len(msg); + let mut buf = vec![0; len]; + encode_message(msg, &mut buf).map(|_| Bytes::from(buf)) +} + +/// Encode a relay message into a Bytes. +pub fn encode_relay_message_to_bytes(msg: &T, node: &Node) -> Result +where + T: Encodable, + I: Data, + A: Data, +{ + let len = encoded_relay_message_len(msg, node); + let mut buf = vec![0; len]; + encode_relay_message(msg, node, &mut buf).map(|_| Bytes::from(buf)) +} + +/// Encode a message into a buffer. +pub fn encode_message(msg: &T, buf: &mut [u8]) -> Result +where + T: Encodable, +{ + let mut offset = 0; + let buf_len = buf.len(); + bail!(msg(offset, buf_len)); + + buf[offset] = T::ID; + offset += 1; + + let encoded_len = msg.encoded_len(); + if encoded_len > u32::MAX as usize { + return Err(EncodeError::TooLarge); + } + + offset += (encoded_len as u32).encode(&mut buf[offset..]).map_err(|e| e.update(encoded_message_len(msg), buf_len))?; + + offset += msg.encode(&mut buf[offset..]).map_err(|e| e.update(encoded_message_len(msg), buf_len))?; + + #[cfg(debug_assertions)] + { + struct Message(core::marker::PhantomData); + super::debug_assert_write_eq::>(offset, encoded_message_len(msg)); + } + + Ok(offset) +} + +/// Encode a relay message into a buffer. 
+pub fn encode_relay_message(msg: &T, node: &Node, buf: &mut [u8]) -> Result +where + T: Encodable, + I: Data, + A: Data, +{ + let mut offset = 0; + let buf_len = buf.len(); + bail!(@relay msg(offset, buf_len, node)); + + buf[offset] = RELAY_MESSAGE_BYTE; + offset += 1; + + bail!(@relay msg(offset, buf_len, node)); + buf[offset] = RELAY_NODE_BYTE; + offset += 1; + offset += node + .encode_length_delimited(&mut buf[offset..]) + .map_err(|e| e.update(encoded_relay_message_len(msg, node), buf_len))?; + + bail!(@relay msg(offset, buf_len, node)); + buf[offset] = RELAY_MSG_BYTE; + offset += 1; + + bail!(@relay msg(offset, buf_len, node)); + buf[offset] = T::ID; + offset += 1; + + let encoded_len = msg.encoded_len(); + if encoded_len > u32::MAX as usize { + return Err(EncodeError::TooLarge); + } + + offset += (encoded_len as u32).encode(&mut buf[offset..]).map_err(|e| e.update(encoded_relay_message_len(msg, node), buf_len))?; + offset += msg.encode(&mut buf[offset..]).map_err(|e| e.update(encoded_relay_message_len(msg, node), buf_len))?; + + + #[cfg(debug_assertions)] + { + struct Message(core::marker::PhantomData); + super::debug_assert_write_eq::>(offset, encoded_relay_message_len(msg, node)); + } + + Ok(offset) +} + +/// Returns the encoded length of a message. +pub fn encoded_message_len(msg: &T) -> usize +where + T: Encodable, +{ + let encoded_len = msg.encoded_len(); + 1 + (encoded_len as u32).encoded_len() + encoded_len +} + +/// Returns the encoded length of the relay message. +pub fn encoded_relay_message_len(msg: &T, node: &Node) -> usize +where + T: Encodable, + I: Data, + A: Data, +{ + 1 + 1 + node.encoded_len_with_length_delimited() + 1 + { + let encoded_len = msg.encoded_len(); + 1 + (encoded_len as u32).encoded_len() + encoded_len + } +} + /// Decode a message from a buffer. 
pub fn decode_message( buf: &[u8], @@ -674,15 +711,18 @@ where offset += 1; // Skip length-delimited field by reading the length and skipping the payload - if buf[offset..].is_empty() { + if buf[offset..].len() < 2 { return Err(DecodeError::buffer_underflow()); } + let start_offset = offset; + let _ = buf[offset]; + offset += 1; + let (read, length) = ::decode(&buf[offset..])?; offset += read; - - msg = Some(&buf[offset..offset + length as usize]); offset += length as usize; + msg = Some(&buf[start_offset..offset]); } other => { offset += 1; diff --git a/serf-core/src/types/push_pull.rs b/serf-core/src/types/push_pull.rs index 4383a86..08640a3 100644 --- a/serf-core/src/types/push_pull.rs +++ b/serf-core/src/types/push_pull.rs @@ -450,7 +450,7 @@ where offset += self.query_ltime.encode(&mut buf[offset..])?; #[cfg(debug_assertions)] - super::debug_assert_write_eq(offset, self.encoded_len()); + super::debug_assert_write_eq::(offset, self.encoded_len()); Ok(offset) } @@ -595,7 +595,7 @@ where offset += self.query_ltime.encode(&mut buf[offset..])?; #[cfg(debug_assertions)] - super::debug_assert_write_eq(offset, self.encoded_len_in()); + super::debug_assert_write_eq::(offset, self.encoded_len_in()); Ok(offset) } diff --git a/serf-core/src/types/query.rs b/serf-core/src/types/query.rs index f4935c0..8812b70 100644 --- a/serf-core/src/types/query.rs +++ b/serf-core/src/types/query.rs @@ -451,7 +451,6 @@ where bail!(self(offset, buf_len)); buf[offset] = LTIME_BYTE; offset += 1; - offset += self .ltime .encode(&mut buf[offset..]) @@ -460,7 +459,6 @@ where bail!(self(offset, buf_len)); buf[offset] = ID_BYTE; offset += 1; - offset += self .id .encode(&mut buf[offset..]) @@ -469,7 +467,6 @@ where bail!(self(offset, buf_len)); buf[offset] = FROM_BYTE; offset += 1; - offset += self .from .encode_length_delimited(&mut buf[offset..]) @@ -488,14 +485,12 @@ where bail!(self(offset, buf_len)); buf[offset] = FLAGS_BYTE; offset += 1; - offset += ::encode(&self.flags.bits(), &mut 
buf[offset..]) .map_err(|e| e.update(self.encoded_len(), buf_len))?; bail!(self(offset, buf_len)); buf[offset] = RELAY_FACTOR_BYTE; offset += 1; - bail!(self(offset, buf_len)); buf[offset] = self.relay_factor; offset += 1; @@ -503,7 +498,6 @@ where bail!(self(offset, buf_len)); buf[offset] = TIMEOUT_BYTE; offset += 1; - offset += self .timeout .encode(&mut buf[offset..]) @@ -513,7 +507,6 @@ where bail!(self(offset, buf_len)); buf[offset] = NAME_BYTE; offset += 1; - offset += self .name .encode_length_delimited(&mut buf[offset..]) @@ -524,7 +517,6 @@ where bail!(self(offset, buf_len)); buf[offset] = PAYLOAD_BYTE; offset += 1; - offset += self .payload .encode_length_delimited(&mut buf[offset..]) @@ -532,7 +524,7 @@ where } #[cfg(debug_assertions)] - super::debug_assert_write_eq(offset, self.encoded_len()); + super::debug_assert_write_eq::(offset, self.encoded_len()); Ok(offset) } diff --git a/serf-core/src/types/query/response.rs b/serf-core/src/types/query/response.rs index 999ef30..f2e9e99 100644 --- a/serf-core/src/types/query/response.rs +++ b/serf-core/src/types/query/response.rs @@ -159,7 +159,7 @@ where offset += 1; let (o, v) = - , A::Ref<'_>> as DataRef<'_, Node>>::decode(&buf[offset..])?; + , A::Ref<'_>> as DataRef<'_, Node>>::decode_length_delimited(&buf[offset..])?; offset += o; from = Some(v); } @@ -187,7 +187,7 @@ where } offset += 1; - let (o, v) = <&[u8] as DataRef<'_, Bytes>>::decode(&buf[offset..])?; + let (o, v) = <&[u8] as DataRef<'_, Bytes>>::decode_length_delimited(&buf[offset..])?; offset += o; payload = Some(v); } @@ -262,7 +262,6 @@ where bail!(self(offset, buf_len)); buf[offset] = LTIME_BYTE; offset += 1; - offset += self .ltime .encode(&mut buf[offset..]) @@ -271,7 +270,6 @@ where bail!(self(offset, buf_len)); buf[offset] = ID_BYTE; offset += 1; - offset += self .id .encode(&mut buf[offset..]) @@ -305,8 +303,7 @@ where } #[cfg(debug_assertions)] - super::super::debug_assert_write_eq(offset, self.encoded_len()); - + 
super::super::debug_assert_write_eq::(offset, self.encoded_len()); Ok(offset) } } diff --git a/serf-core/src/types/quickcheck_impl.rs b/serf-core/src/types/quickcheck_impl.rs index 2cce92d..953131a 100644 --- a/serf-core/src/types/quickcheck_impl.rs +++ b/serf-core/src/types/quickcheck_impl.rs @@ -6,7 +6,8 @@ use smol_str::SmolStr; use super::{ ConflictResponseMessage, DelegateVersion, Filter, JoinMessage, LamportTime, LeaveMessage, Member, MemberStatus, MessageType, ProtocolVersion, PushPullMessage, QueryFlag, QueryMessage, - QueryResponseMessage, TagFilter, Tags, UserEvent, UserEventMessage, UserEvents, coordinate::Coordinate, + QueryResponseMessage, TagFilter, Tags, UserEvent, UserEventMessage, UserEvents, + coordinate::Coordinate, }; #[cfg(feature = "encryption")] diff --git a/serf-core/src/types/tags.rs b/serf-core/src/types/tags.rs index d99970f..cae7a42 100644 --- a/serf-core/src/types/tags.rs +++ b/serf-core/src/types/tags.rs @@ -78,6 +78,8 @@ impl<'a> DataRef<'a, Tags> for TagsRef<'a> { while offset < buf_len { match src[offset] { TAGS_BYTE => { + offset += 1; + let readed = skip(WireType::LengthDelimited, &src[offset..])?; if let Some((ref mut fnso, ref mut lnso)) = tags_offsets { if *fnso > offset { diff --git a/serf-core/src/types/tests.rs b/serf-core/src/types/tests.rs index 4dfdc26..d81cf0a 100644 --- a/serf-core/src/types/tests.rs +++ b/serf-core/src/types/tests.rs @@ -1,9 +1,109 @@ use std::hash::Hash; +use memberlist_core::{ + bytes::Bytes, + proto::{Data, DataRef}, +}; use quickcheck::{Arbitrary, Gen}; use super::*; +fn data_round_trip(data: &T) { + let mut buf = vec![0; data.encoded_len() + 2]; + let len = data.encode(&mut buf).unwrap(); + let buf = &buf[..len]; + let (readed, decoded) = DataRef::decode(&buf[..len]).unwrap(); + let decoded = T::from_ref(decoded).unwrap(); + assert_eq!(len, readed); + assert_eq!(data, &decoded); +} + +macro_rules! data_round_trip { + ($( + $(#[$attr:meta])* + $ty:ty + ),+$(,)?) => { + $( + paste::paste! 
{ + $(#[$attr])* + #[quickcheck_macros::quickcheck] + fn [< data_round_trip_ $ty:snake >](data: $ty) { + data_round_trip(&data); + } + } + )* + }; +} + +type StringFilter = Filter; +type U64Filter = Filter; + +type QueryMessageStringString = QueryMessage; +type QueryMessageU64String = QueryMessage; +type QueryMessageStringU64 = QueryMessage; +type QueryMessageU64U64 = QueryMessage; + +type PushPullMessageString = PushPullMessage; +type PushPullMessageU64 = PushPullMessage; + +type MemberStringString = Member; +type MemberU64String = Member; +type MemberStringU64 = Member; +type MemberU64U64 = Member; + +type LeaveMessageString = LeaveMessage; +type LeaveMessageU64 = LeaveMessage; + +type JoinMessageString = JoinMessage; +type JoinMessageU64 = JoinMessage; + +type ConflictResponseMessageStringString = ConflictResponseMessage; +type ConflictResponseMessageU64U64 = ConflictResponseMessage; +type ConflictResponseMessageStringU64 = ConflictResponseMessage; + + +// QueryMessageStringString, +// QueryMessageU64U64, +// PushPullMessageString, +// PushPullMessageU64, + +type QueryResponseMessageStringString = QueryResponseMessage; +type QueryResponseMessageU64U64 = QueryResponseMessage; +type QueryResponseMessageStringU64 = QueryResponseMessage; + +data_round_trip!( + QueryMessageU64U64, +); + + +data_round_trip! 
{ + ConflictResponseMessageStringString, + ConflictResponseMessageU64U64, + ConflictResponseMessageStringU64, + JoinMessageString, + JoinMessageU64, + LeaveMessageString, + LeaveMessageU64, + MemberStringString, + MemberU64String, + MemberStringU64, + MemberU64U64, + Tags, + TagFilter, + StringFilter, + U64Filter, + UserEvent, + UserEvents, + UserEventMessage, + QueryResponseMessageStringString, + QueryResponseMessageU64U64, + QueryResponseMessageStringU64, + #[cfg(feature = "encryption")] + KeyRequestMessage, + #[cfg(feature = "encryption")] + KeyResponseMessage, +} + #[derive(Clone, Debug)] enum Message { /// Leave message @@ -28,18 +128,19 @@ enum Message { KeyResponse(KeyResponseMessage), } - impl Arbitrary for Message where I: Arbitrary + Hash + Eq, A: Arbitrary, { - fn arbitrary(g: &mut Gen) -> Self { + fn arbitrary(g: &mut Gen) -> Self { loop { let variant = MessageType::arbitrary(g); return match variant { - MessageType::ConflictResponse => Message::ConflictResponse(ConflictResponseMessage::arbitrary(g)), + MessageType::ConflictResponse => { + Message::ConflictResponse(ConflictResponseMessage::arbitrary(g)) + } MessageType::Join => Message::Join(JoinMessage::arbitrary(g)), MessageType::Leave => Message::Leave(LeaveMessage::arbitrary(g)), MessageType::PushPull => Message::PushPull(PushPullMessage::arbitrary(g)), @@ -56,9 +157,195 @@ where } } +fn encode(data: &T) -> Bytes { + encode_message_to_bytes(data).unwrap() +} + +fn encode_relay(data: &T, node: &Node) -> Bytes +where + I: Data, + A: Data, + T: Encodable, +{ + encode_relay_message_to_bytes(data, node).unwrap() +} + +fn encodable_round_trip(msg: Message, node: Option>) -> bool +where + I: Data + Eq + Hash, + A: Data + PartialEq, +{ + macro_rules! encode_variant { + (< $($g:ty), +$(,)? 
> $variant:ident ($ty:ty) <- $input:ident) => {{ + let data = encode(&$input); + assert_eq!(data.len(), encoded_message_len(&$input), "message: length mismatch"); + let decoded = super::decode_message :: < $($g),* > (&data).unwrap(); + let MessageRef::$variant(decoded) = decoded else { return false }; + + let owned = <$ty as Data>::from_ref(decoded).unwrap(); + assert_eq!($input, owned, "message: decoded mismatch"); + true + }}; + (@relay< $($g:ty), +$(,)? > $variant:ident ($ty:ty) <- ($input:ident, $node:ident)) => {{ + let data = encode_relay(&$input, &$node); + assert_eq!(data.len(), encoded_relay_message_len(&$input, &$node), "relay message: length mismatch"); + let decoded = super::decode_message :: < $($g),* > (&data).unwrap(); + let MessageRef::Relay { node, payload, .. } = decoded else { return false }; + assert_eq!( as Data>::from_ref(node).unwrap(), $node, "relay message: node mismatch"); -#[quickcheck_macros::quickcheck] -fn message_arbitrary(msg: Message, node: Option>) -> bool { - todo!() + let decoded = super::decode_message :: < $($g),* > (&payload).unwrap(); + let MessageRef::$variant(decoded) = decoded else { return false }; + + let owned = <$ty as Data>::from_ref(decoded).unwrap(); + assert_eq!($input, owned, "relay message: decoded mismatch"); + true + }}; + } + + match node { + Some(node) => match msg { + Message::Leave(leave_message) => { + encode_variant!(@relay Leave(LeaveMessage) <- (leave_message, node)) + } + Message::Join(join_message) => { + encode_variant!(@relay Join(JoinMessage) <- (join_message, node)) + } + Message::PushPull(push_pull_message) => { + encode_variant!(@relay PushPull(PushPullMessage) <- (push_pull_message, node)) + } + Message::UserEvent(user_event_message) => { + encode_variant!(@relay UserEvent(UserEventMessage) <- (user_event_message, node)) + } + Message::Query(query_message) => { + encode_variant!(@relay Query(QueryMessage) <- (query_message, node)) + } + Message::QueryResponse(query_response_message) => { + 
encode_variant!(@relay QueryResponse(QueryResponseMessage) <- (query_response_message, node)) + } + Message::ConflictResponse(conflict_response_message) => { + encode_variant!(@relay ConflictResponse(ConflictResponseMessage) <- (conflict_response_message, node)) + } + #[cfg(feature = "encryption")] + Message::KeyRequest(key_request_message) => { + encode_variant!(@relay KeyRequest(KeyRequestMessage) <- (key_request_message, node)) + } + #[cfg(feature = "encryption")] + Message::KeyResponse(key_response_message) => { + encode_variant!(@relay KeyResponse(KeyResponseMessage) <- (key_response_message, node)) + } + }, + None => match msg { + Message::Leave(msg) => encode_variant!( Leave(LeaveMessage) <- msg), + Message::Join(join_message) => encode_variant!( Join(JoinMessage) <- join_message), + Message::PushPull(push_pull_message) => { + encode_variant!( PushPull(PushPullMessage) <- push_pull_message) + } + Message::UserEvent(user_event_message) => { + encode_variant!( UserEvent(UserEventMessage) <- user_event_message) + } + Message::Query(query_message) => { + encode_variant!( Query(QueryMessage) <- query_message) + } + Message::QueryResponse(query_response_message) => { + encode_variant!( QueryResponse(QueryResponseMessage) <- query_response_message) + } + Message::ConflictResponse(conflict_response_message) => { + encode_variant!( ConflictResponse(ConflictResponseMessage) <- conflict_response_message) + } + #[cfg(feature = "encryption")] + Message::KeyRequest(key_request_message) => { + encode_variant!( KeyRequest(KeyRequestMessage) <- key_request_message) + } + #[cfg(feature = "encryption")] + Message::KeyResponse(key_response_message) => { + encode_variant!( KeyResponse(KeyResponseMessage) <- key_response_message) + } + }, + } +} + +macro_rules! encodable_round_trip { + (@message $(<$a:ty, $b:ty>),+$(,)?) => { + $( + paste::paste! 
{ + #[quickcheck_macros::quickcheck] + fn [< message _encodable_round_trip_ $a:snake _ $b:snake >](msg: Message<$a, $b>, node: Option>) -> bool { + encodable_round_trip(msg, node) + } + } + )* + }; + (@query_message $(<$a:ty, $b:ty>),+$(,)?) => { + $( + paste::paste! { + #[quickcheck_macros::quickcheck] + fn [< query_message _encodable_round_trip_ $a:snake _ $b:snake >](msg: QueryMessage<$a, $b>, node: Option>) -> bool { + encodable_round_trip(Message::Query(msg), node) + } + } + )* + }; + (@query_response_message $(<$a:ty, $b:ty>),+$(,)?) => { + $( + paste::paste! { + #[quickcheck_macros::quickcheck] + fn [< query_response_message _encodable_round_trip_ $a:snake _ $b:snake >](msg: QueryResponseMessage<$a, $b>, node: Option>) -> bool { + encodable_round_trip(Message::QueryResponse(msg), node) + } + } + )* + }; + (@conflict_response_message $(<$a:ty, $b:ty>),+$(,)?) => { + $( + paste::paste! { + #[quickcheck_macros::quickcheck] + fn [< conflict_response_message _encodable_round_trip_ $a:snake _ $b:snake >](msg: ConflictResponseMessage<$a, $b>, node: Option>) -> bool { + encodable_round_trip(Message::ConflictResponse(msg), node) + } + } + )* + }; + (@join_message $(<$a:ty, $b:ty>),+$(,)?) => { + $( + paste::paste! 
{ + #[quickcheck_macros::quickcheck] + fn [< join_message _encodable_round_trip_ $a:snake _ $b:snake >](msg: JoinMessage<$a>, node: Option>) -> bool { + encodable_round_trip(Message::Join(msg), node) + } + } + )* + }; } +encodable_round_trip!( + @message + , + , + , + , +); + +encodable_round_trip!( + @query_message + , + // , + // , + // , +); + +encodable_round_trip!( + @query_response_message + , + // , + // , + // , +); + +encodable_round_trip!( + @conflict_response_message + , + // , + // , + // , +); + diff --git a/serf-core/src/types/user_event/message.rs b/serf-core/src/types/user_event/message.rs index ff9724b..cc6bebf 100644 --- a/serf-core/src/types/user_event/message.rs +++ b/serf-core/src/types/user_event/message.rs @@ -271,7 +271,7 @@ impl Data for UserEventMessage { } #[cfg(debug_assertions)] - super::super::debug_assert_write_eq(offset, self.encoded_len()); + super::super::debug_assert_write_eq::(offset, self.encoded_len()); Ok(offset) } diff --git a/serf-core/src/types/user_event/user_events.rs b/serf-core/src/types/user_event/user_events.rs index 3fe019a..9edd6f5 100644 --- a/serf-core/src/types/user_event/user_events.rs +++ b/serf-core/src/types/user_event/user_events.rs @@ -77,6 +77,8 @@ impl<'a> DataRef<'a, UserEvents> for UserEventsRef<'a> { offset += size; } EVENTS_BYTE => { + offset += 1; + let readed = super::skip(WireType::LengthDelimited, &buf[offset..])?; if let Some((ref mut fnso, ref mut lnso)) = events_offsets { if *fnso > offset { @@ -170,8 +172,15 @@ impl Data for UserEvents { .events .iter() .try_fold(&mut offset, |offset, ev| { + if *offset >= buf_len { + return Err(EncodeError::insufficient_buffer( + self.encoded_len(), + buf_len, + )); + } + buf[*offset] = EVENTS_BYTE; + *offset += 1; *offset += ev.encode_length_delimited(&mut buf[*offset..])?; - Ok(offset) }) .map(|offset| *offset) From d16e1f3de66306bb2b3e60ba00670a63d08fdf5e Mon Sep 17 00:00:00 2001 From: al8n Date: Sun, 2 Mar 2025 16:27:39 +0800 Subject: [PATCH 13/39] Fix 
all types encoding/decoding tests --- serf-core/src/types/filter.rs | 5 +- serf-core/src/types/message.rs | 28 +++-- serf-core/src/types/push_pull.rs | 71 ++++++----- serf-core/src/types/query.rs | 5 +- serf-core/src/types/query/response.rs | 4 +- serf-core/src/types/tests.rs | 166 ++++++++++++++++++++++---- serf-core/src/types/version.rs | 8 +- 7 files changed, 209 insertions(+), 78 deletions(-) diff --git a/serf-core/src/types/filter.rs b/serf-core/src/types/filter.rs index 376992c..f4c0ed1 100644 --- a/serf-core/src/types/filter.rs +++ b/serf-core/src/types/filter.rs @@ -162,10 +162,7 @@ where .with_offsets(start, end), ) } else { - Self::Id( - RepeatedDecoder::new(FILTER_ID_TAG, I::WIRE_TYPE, buf) - .with_nums(0), - ) + Self::Id(RepeatedDecoder::new(FILTER_ID_TAG, I::WIRE_TYPE, buf).with_nums(0)) }, )) } diff --git a/serf-core/src/types/message.rs b/serf-core/src/types/message.rs index 5d10d12..915ec2a 100644 --- a/serf-core/src/types/message.rs +++ b/serf-core/src/types/message.rs @@ -364,7 +364,10 @@ where } /// Encode a relay message into a Bytes. -pub fn encode_relay_message_to_bytes(msg: &T, node: &Node) -> Result +pub fn encode_relay_message_to_bytes( + msg: &T, + node: &Node, +) -> Result where T: Encodable, I: Data, @@ -392,9 +395,13 @@ where return Err(EncodeError::TooLarge); } - offset += (encoded_len as u32).encode(&mut buf[offset..]).map_err(|e| e.update(encoded_message_len(msg), buf_len))?; + offset += (encoded_len as u32) + .encode(&mut buf[offset..]) + .map_err(|e| e.update(encoded_message_len(msg), buf_len))?; - offset += msg.encode(&mut buf[offset..]).map_err(|e| e.update(encoded_message_len(msg), buf_len))?; + offset += msg + .encode(&mut buf[offset..]) + .map_err(|e| e.update(encoded_message_len(msg), buf_len))?; #[cfg(debug_assertions)] { @@ -406,7 +413,11 @@ where } /// Encode a relay message into a buffer. 
-pub fn encode_relay_message(msg: &T, node: &Node, buf: &mut [u8]) -> Result +pub fn encode_relay_message( + msg: &T, + node: &Node, + buf: &mut [u8], +) -> Result where T: Encodable, I: Data, @@ -439,9 +450,12 @@ where return Err(EncodeError::TooLarge); } - offset += (encoded_len as u32).encode(&mut buf[offset..]).map_err(|e| e.update(encoded_relay_message_len(msg, node), buf_len))?; - offset += msg.encode(&mut buf[offset..]).map_err(|e| e.update(encoded_relay_message_len(msg, node), buf_len))?; - + offset += (encoded_len as u32) + .encode(&mut buf[offset..]) + .map_err(|e| e.update(encoded_relay_message_len(msg, node), buf_len))?; + offset += msg + .encode(&mut buf[offset..]) + .map_err(|e| e.update(encoded_relay_message_len(msg, node), buf_len))?; #[cfg(debug_assertions)] { diff --git a/serf-core/src/types/push_pull.rs b/serf-core/src/types/push_pull.rs index 08640a3..20f0291 100644 --- a/serf-core/src/types/push_pull.rs +++ b/serf-core/src/types/push_pull.rs @@ -106,11 +106,15 @@ const QUERY_LTIME_TAG: u8 = 6; const LTIME_BYTE: u8 = merge(WireType::Varint, LTIME_TAG); const STATUS_LTIMES_BYTE: u8 = merge(WireType::LengthDelimited, STATUS_LTIMES_TAG); -const LEFT_MEMBERS_BYTE: u8 = merge(WireType::LengthDelimited, LEFT_MEMBERS_TAG); const EVENT_LTIME_BYTE: u8 = merge(WireType::Varint, EVENT_LTIME_TAG); const EVENTS_BYTE: u8 = merge(WireType::LengthDelimited, EVENTS_TAG); const QUERY_LTIME_BYTE: u8 = merge(WireType::Varint, QUERY_LTIME_TAG); +#[inline] +const fn left_members_byte() -> u8 { + merge(I::WIRE_TYPE, LEFT_MEMBERS_TAG) +} + /// Used when doing a state exchange. 
This /// is a relatively large message, but is sent infrequently #[viewit::viewit(vis_all = "", getters(vis_all = "pub"), setters(skip))] @@ -179,6 +183,8 @@ where let mut num_events = 0; let mut query_ltime = None; + let left_members_byte = left_members_byte::(); + while offset < buf_len { match buf[offset] { LTIME_BYTE => { @@ -196,6 +202,7 @@ where ltime = Some(v); } STATUS_LTIMES_BYTE => { + offset += 1; let readed = skip(WireType::LengthDelimited, &buf[offset..])?; if let Some((ref mut fnso, ref mut lnso)) = status_ltimes_offsets { if *fnso > offset { @@ -211,8 +218,9 @@ where num_status_ltimes += 1; offset += readed; } - LEFT_MEMBERS_BYTE => { - let readed = skip(WireType::LengthDelimited, &buf[offset..])?; + b if b == left_members_byte => { + offset += 1; + let readed = skip(I::WIRE_TYPE, &buf[offset..])?; if let Some((ref mut fnso, ref mut lnso)) = left_members_offsets { if *fnso > offset { *fnso = offset - 1; @@ -242,6 +250,7 @@ where event_ltime = Some(v); } EVENTS_BYTE => { + offset += 1; let readed = skip(WireType::LengthDelimited, &buf[offset..])?; if let Some((ref mut fnso, ref mut lnso)) = events_offsets { if *fnso > offset { @@ -285,19 +294,19 @@ where offset, Self { ltime: ltime.ok_or_else(|| DecodeError::missing_field("PushPullMessage", "ltime"))?, - status_ltimes: if let Some((start, end)) = events_offsets { + status_ltimes: if let Some((start, end)) = status_ltimes_offsets { RepeatedDecoder::new(STATUS_LTIMES_TAG, WireType::LengthDelimited, buf) .with_nums(num_status_ltimes) .with_offsets(start, end) } else { RepeatedDecoder::new(STATUS_LTIMES_TAG, WireType::LengthDelimited, buf) }, - left_members: if let Some((start, end)) = events_offsets { - RepeatedDecoder::new(LEFT_MEMBERS_TAG, WireType::LengthDelimited, buf) + left_members: if let Some((start, end)) = left_members_offsets { + RepeatedDecoder::new(LEFT_MEMBERS_TAG, I::WIRE_TYPE, buf) .with_nums(num_left_members) .with_offsets(start, end) } else { - RepeatedDecoder::new(LEFT_MEMBERS_TAG, 
WireType::LengthDelimited, buf) + RepeatedDecoder::new(LEFT_MEMBERS_TAG, I::WIRE_TYPE, buf) }, event_ltime: event_ltime .ok_or_else(|| DecodeError::missing_field("PushPullMessage", "event_ltime"))?, @@ -326,12 +335,6 @@ where where Self: Sized, { - let left_members = val - .left_members - .iter::() - .map(|res| res.and_then(Data::from_ref)) - .collect::, DecodeError>>()?; - Ok(Self { ltime: val.ltime, status_ltimes: val @@ -339,7 +342,11 @@ where .iter::<(I, LamportTime)>() .map(|res| res.and_then(Data::from_ref)) .collect::, DecodeError>>()?, - left_members, + left_members: val + .left_members + .iter::() + .map(|res| res.and_then(Data::from_ref)) + .collect::, DecodeError>>()?, event_ltime: val.event_ltime, events: val .events @@ -366,13 +373,15 @@ where .iter() .map(|id| 1 + id.encoded_len_with_length_delimited()) .sum::(); + len += 1 + self.event_ltime.encoded_len(); - len += 1 - + self - .events - .iter() - .map(|e| 1 + e.encoded_len_with_length_delimited()) - .sum::(); + + len += self + .events + .iter() + .map(|e| 1 + e.encoded_len_with_length_delimited()) + .sum::(); + len += 1 + self.query_ltime.encoded_len(); len @@ -395,28 +404,25 @@ where offset += 1; offset += self.ltime.encode(&mut buf[offset..])?; - bail!(self(offset, buf_len)); - buf[offset] = STATUS_LTIMES_BYTE; - offset += 1; - self .status_ltimes .iter() .try_fold(&mut offset, |off, (k, v)| { bail!(self(*off, buf_len)); - buf[*off] = LEFT_MEMBERS_BYTE; + buf[*off] = STATUS_LTIMES_BYTE; *off += 1; *off += TupleEncoder::new(k, v).encode_with_length_delimited(&mut buf[*off..])?; Ok(off) }) .map_err(|e: EncodeError| e.update(self.encoded_len(), buf_len))?; + let left_members_byte = left_members_byte::(); self .left_members .iter() .try_fold(&mut offset, |off, id| { bail!(self(*off, buf_len)); - buf[*off] = LEFT_MEMBERS_BYTE; + buf[*off] = left_members_byte; *off += 1; *off += id.encode_length_delimited(&mut buf[*off..])?; Ok(off) @@ -428,10 +434,6 @@ where offset += 1; offset += 
self.event_ltime.encode(&mut buf[offset..])?; - bail!(self(offset, buf_len)); - buf[offset] = EVENTS_BYTE; - offset += 1; - self .events .iter() @@ -548,19 +550,20 @@ where .iter() .try_fold(&mut offset, |off, (k, v)| { bail!(self(*off, buf_len)); - buf[*off] = LEFT_MEMBERS_BYTE; + buf[*off] = STATUS_LTIMES_BYTE; *off += 1; *off += TupleEncoder::new(k, v).encode_with_length_delimited(&mut buf[*off..])?; Ok(off) }) .map_err(|e: EncodeError| e.update(self.encoded_len_in(), buf_len))?; + let left_members_byte = left_members_byte::(); self .left_members .iter() .try_fold(&mut offset, |off, id| { bail!(self(*off, buf_len)); - buf[*off] = LEFT_MEMBERS_BYTE; + buf[*off] = left_members_byte; *off += 1; *off += id.encode_length_delimited(&mut buf[*off..])?; Ok(off) @@ -572,10 +575,6 @@ where offset += 1; offset += self.event_ltime.encode(&mut buf[offset..])?; - bail!(self(offset, buf_len)); - buf[offset] = EVENTS_BYTE; - offset += 1; - self .events .iter() diff --git a/serf-core/src/types/query.rs b/serf-core/src/types/query.rs index 8812b70..d899594 100644 --- a/serf-core/src/types/query.rs +++ b/serf-core/src/types/query.rs @@ -249,11 +249,14 @@ where offset += 1; let (o, v) = - , A::Ref<'_>> as DataRef<'_, Node>>::decode(&buf[offset..])?; + , A::Ref<'_>> as DataRef<'_, Node>>::decode_length_delimited( + &buf[offset..], + )?; offset += o; from = Some(v); } FILTERS_BYTE => { + offset += 1; let readed = skip(WireType::LengthDelimited, &buf[offset..])?; if let Some((ref mut fnso, ref mut lnso)) = filters_offsets { if *fnso > offset { diff --git a/serf-core/src/types/query/response.rs b/serf-core/src/types/query/response.rs index f2e9e99..633233e 100644 --- a/serf-core/src/types/query/response.rs +++ b/serf-core/src/types/query/response.rs @@ -159,7 +159,9 @@ where offset += 1; let (o, v) = - , A::Ref<'_>> as DataRef<'_, Node>>::decode_length_delimited(&buf[offset..])?; + , A::Ref<'_>> as DataRef<'_, Node>>::decode_length_delimited( + &buf[offset..], + )?; offset += o; from = 
Some(v); } diff --git a/serf-core/src/types/tests.rs b/serf-core/src/types/tests.rs index d81cf0a..a09b147 100644 --- a/serf-core/src/types/tests.rs +++ b/serf-core/src/types/tests.rs @@ -16,6 +16,14 @@ fn data_round_trip(data: &T) { let decoded = T::from_ref(decoded).unwrap(); assert_eq!(len, readed); assert_eq!(data, &decoded); + + let mut buf = vec![0; data.encoded_len_with_length_delimited() + 2]; + let len = data.encode_length_delimited(&mut buf).unwrap(); + let buf = &buf[..len]; + let (readed, decoded) = DataRef::decode_length_delimited(&buf[..len]).unwrap(); + let decoded = T::from_ref(decoded).unwrap(); + assert_eq!(len, readed); + assert_eq!(data, &decoded); } macro_rules! data_round_trip { @@ -61,21 +69,10 @@ type ConflictResponseMessageStringString = ConflictResponseMessage; type ConflictResponseMessageStringU64 = ConflictResponseMessage; - -// QueryMessageStringString, -// QueryMessageU64U64, -// PushPullMessageString, -// PushPullMessageU64, - type QueryResponseMessageStringString = QueryResponseMessage; type QueryResponseMessageU64U64 = QueryResponseMessage; type QueryResponseMessageStringU64 = QueryResponseMessage; -data_round_trip!( - QueryMessageU64U64, -); - - data_round_trip! { ConflictResponseMessageStringString, ConflictResponseMessageU64U64, @@ -95,6 +92,12 @@ data_round_trip! { UserEvent, UserEvents, UserEventMessage, + PushPullMessageU64, + PushPullMessageString, + QueryMessageStringString, + QueryMessageU64String, + QueryMessageStringU64, + QueryMessageU64U64, QueryResponseMessageStringString, QueryResponseMessageU64U64, QueryResponseMessageStringU64, @@ -269,8 +272,18 @@ macro_rules! encodable_round_trip { $( paste::paste! 
{ #[quickcheck_macros::quickcheck] - fn [< message _encodable_round_trip_ $a:snake _ $b:snake >](msg: Message<$a, $b>, node: Option>) -> bool { - encodable_round_trip(msg, node) + fn [< message _encodable_round_trip_ $a:snake _ $b:snake >](msg: Message<$a, $b>) -> bool { + encodable_round_trip(msg, None) + } + } + )* + }; + (@relay_message $(<$a:ty, $b:ty>),+$(,)?) => { + $( + paste::paste! { + #[quickcheck_macros::quickcheck] + fn [< relay_message _encodable_round_trip_ $a:snake _ $b:snake >](msg: Message<$a, $b>, node: Node<$a, $b>) -> bool { + encodable_round_trip(msg, Some(node)) } } )* @@ -295,6 +308,36 @@ macro_rules! encodable_round_trip { } )* }; + (@join_message $(<$a:ty, $b:ty>),+$(,)?) => { + $( + paste::paste! { + #[quickcheck_macros::quickcheck] + fn [< join_message _encodable_round_trip_ $a:snake _ $b:snake >](msg: JoinMessage<$a>, node: Option>) -> bool { + encodable_round_trip(Message::Join(msg), node) + } + } + )* + }; + (@push_pull_message $(<$a:ty>),+$(,)?) => { + $( + paste::paste! { + #[quickcheck_macros::quickcheck] + fn [< push_pull_message _encodable_round_trip_ $a:snake >](msg: PushPullMessage<$a>, node: Option>) -> bool { + encodable_round_trip(Message::PushPull(msg), node) + } + } + )* + }; + (@leave_message $(<$a:ty>),+$(,)?) => { + $( + paste::paste! { + #[quickcheck_macros::quickcheck] + fn [< leave_message _encodable_round_trip_ $a:snake >](msg: LeaveMessage<$a>, node: Option>) -> bool { + encodable_round_trip(Message::Leave(msg), node) + } + } + )* + }; (@conflict_response_message $(<$a:ty, $b:ty>),+$(,)?) => { $( paste::paste! { @@ -305,12 +348,32 @@ macro_rules! encodable_round_trip { } )* }; - (@join_message $(<$a:ty, $b:ty>),+$(,)?) => { + (@user_event_message $(<$a:ty, $b:ty>),+$(,)?) => { $( paste::paste! 
{ #[quickcheck_macros::quickcheck] - fn [< join_message _encodable_round_trip_ $a:snake _ $b:snake >](msg: JoinMessage<$a>, node: Option>) -> bool { - encodable_round_trip(Message::Join(msg), node) + fn [< user_event_message _encodable_round_trip_ $a:snake _ $b:snake >](msg: UserEventMessage, node: Option>) -> bool { + encodable_round_trip(Message::UserEvent(msg), node) + } + } + )* + }; + (@key_request_message $(<$a:ty, $b:ty>),+$(,)?) => { + $( + paste::paste! { + #[quickcheck_macros::quickcheck] + fn [< key_request_message _encodable_round_trip_ $a:snake _ $b:snake >](msg: KeyRequestMessage, node: Option>) -> bool { + encodable_round_trip(Message::KeyRequest(msg), node) + } + } + )* + }; + (@key_response_message $(<$a:ty, $b:ty>),+$(,)?) => { + $( + paste::paste! { + #[quickcheck_macros::quickcheck] + fn [< key_response_message _encodable_round_trip_ $a:snake _ $b:snake >](msg: KeyResponseMessage, node: Option>) -> bool { + encodable_round_trip(Message::KeyResponse(msg), node) } } )* @@ -325,27 +388,80 @@ encodable_round_trip!( , ); +encodable_round_trip!( + @relay_message + , + , + , + , +); + encodable_round_trip!( @query_message , - // , - // , - // , + , + , + , ); encodable_round_trip!( @query_response_message , - // , - // , - // , + , + , + , ); encodable_round_trip!( @conflict_response_message , - // , - // , - // , + , + , + , +); + +encodable_round_trip!( + @join_message + , + , + , + , +); + +encodable_round_trip!( + @push_pull_message + , + , +); + +encodable_round_trip!( + @leave_message + , + , +); + +encodable_round_trip!( + @user_event_message + , + , + , + , ); +#[cfg(feature = "encryption")] +encodable_round_trip!( + @key_request_message + , + , + , + , +); + +#[cfg(feature = "encryption")] +encodable_round_trip!( + @key_response_message + , + , + , + , +); diff --git a/serf-core/src/types/version.rs b/serf-core/src/types/version.rs index 710d49f..b603f7d 100644 --- a/serf-core/src/types/version.rs +++ b/serf-core/src/types/version.rs @@ -128,8 
+128,8 @@ mod tests { let _ = DelegateVersion::arbitrary(&mut data).unwrap(); assert_eq!(u8::from(DelegateVersion::V1), 1u8); - assert_eq!(DelegateVersion::V1.to_string(), "V1"); - assert_eq!(DelegateVersion::Unknown(2).to_string(), "Unknown(2)"); + assert_eq!(DelegateVersion::V1.to_string(), "v1"); + assert_eq!(DelegateVersion::Unknown(2).to_string(), "unknown(2)"); assert_eq!(DelegateVersion::from(1), DelegateVersion::V1); assert_eq!(DelegateVersion::from(2), DelegateVersion::Unknown(2)); } @@ -143,8 +143,8 @@ mod tests { let mut data = Unstructured::new(&buf); let _ = ProtocolVersion::arbitrary(&mut data).unwrap(); assert_eq!(u8::from(ProtocolVersion::V1), 1); - assert_eq!(ProtocolVersion::V1.to_string(), "V1"); - assert_eq!(ProtocolVersion::Unknown(2).to_string(), "Unknown(2)"); + assert_eq!(ProtocolVersion::V1.to_string(), "v1"); + assert_eq!(ProtocolVersion::Unknown(2).to_string(), "unknown(2)"); assert_eq!(ProtocolVersion::from(1), ProtocolVersion::V1); assert_eq!(ProtocolVersion::from(2), ProtocolVersion::Unknown(2)); } From 3dcf331b8302e1044f3ef37738d28d218f2f51da Mon Sep 17 00:00:00 2001 From: al8n Date: Sun, 2 Mar 2025 23:31:46 +0800 Subject: [PATCH 14/39] Fix unit tests --- serf-core/Cargo.toml | 12 ++ serf-core/src/coalesce/member.rs | 2 +- serf-core/src/coalesce/user.rs | 2 +- serf-core/src/lib.rs | 2 +- serf-core/src/serf/base/tests.rs | 2 +- serf-core/src/serf/base/tests/serf.rs | 17 +- .../src/serf/base/tests/serf/delegate.rs | 10 +- serf-core/src/serf/base/tests/serf/event.rs | 191 +++++++++--------- serf-core/src/serf/base/tests/serf/join.rs | 30 +-- serf-core/src/serf/base/tests/serf/leave.rs | 12 +- serf-core/src/serf/base/tests/serf/reap.rs | 6 +- .../src/serf/base/tests/serf/reconnect.rs | 4 +- .../src/serf/base/tests/serf/snapshot.rs | 70 +++---- serf-core/src/serf/delegate.rs | 5 +- serf-core/src/types.rs | 40 ++++ serf-core/src/types/filter.rs | 13 +- serf-core/src/types/filter/tag_filter.rs | 7 + serf-core/src/types/query/response.rs | 12 
+- serf/src/async_std.rs | 50 +---- serf/src/lib.rs | 50 ++--- serf/src/smol.rs | 50 +---- serf/src/tokio.rs | 50 +---- serf/test/main.rs | 48 ++--- serf/test/main/net/coordinates.rs | 8 +- serf/test/main/net/delegate/local_state.rs | 8 +- serf/test/main/net/delegate/nodemeta.rs | 12 +- serf/test/main/net/delegate/ping_delegate.rs | 12 +- serf/test/main/net/delegate/remote_state.rs | 8 +- serf/test/main/net/event.rs | 3 - serf/test/main/net/event/default_query.rs | 8 +- serf/test/main/net/event/event_user.rs | 8 +- .../main/net/event/event_user_size_limit.rs | 8 +- serf/test/main/net/event/events_failed.rs | 8 +- serf/test/main/net/event/events_join.rs | 8 +- serf/test/main/net/event/events_leave.rs | 8 +- ...events_leave_avoid_infinite_rebroadcast.rs | 8 +- serf/test/main/net/event/query.rs | 8 +- serf/test/main/net/event/query_deduplicate.rs | 8 +- serf/test/main/net/event/query_filter.rs | 8 +- serf/test/main/net/event/query_old_message.rs | 8 +- .../net/event/query_params_encode_filters.rs | 64 ------ serf/test/main/net/event/query_same_clock.rs | 8 +- serf/test/main/net/event/query_size_limit.rs | 8 +- .../net/event/query_size_limit_increased.rs | 7 +- .../net/event/remove_failed_events_leave.rs | 8 +- serf/test/main/net/event/should_process.rs | 8 +- .../main/net/event/user_event_old_message.rs | 8 +- .../main/net/event/user_event_same_clock.rs | 8 +- serf/test/main/net/get_queue_max.rs | 8 +- .../test/main/net/join/intent_buffer_early.rs | 8 +- serf/test/main/net/join/intent_newer.rs | 8 +- serf/test/main/net/join/intent_old_message.rs | 8 +- .../main/net/join/intent_reset_leaving.rs | 8 +- serf/test/main/net/join/join_cancel.rs | 8 +- serf/test/main/net/join/join_ignore_old.rs | 8 +- serf/test/main/net/join/join_leave.rs | 8 +- serf/test/main/net/join/join_leave_join.rs | 8 +- serf/test/main/net/join/leave_ltime.rs | 8 +- serf/test/main/net/join/pending_intent.rs | 8 +- serf/test/main/net/join/pending_intents.rs | 8 +- 
.../test/main/net/leave/force_leave_failed.rs | 8 +- .../main/net/leave/force_leave_leaving.rs | 8 +- serf/test/main/net/leave/force_leave_left.rs | 8 +- .../main/net/leave/intent_buffer_early.rs | 8 +- serf/test/main/net/leave/intent_newer.rs | 8 +- .../test/main/net/leave/intent_old_message.rs | 8 +- .../main/net/leave/rejoin_different_role.rs | 8 +- serf/test/main/net/leave/snapshot_recovery.rs | 8 +- serf/test/main/net/local_member.rs | 8 +- serf/test/main/net/name_resolution.rs | 8 +- serf/test/main/net/num_nodes.rs | 8 +- serf/test/main/net/reap/handler.rs | 8 +- serf/test/main/net/reap/handler_shutdown.rs | 8 +- serf/test/main/net/reap/reap.rs | 8 +- serf/test/main/net/reconnect/reconnect.rs | 8 +- serf/test/main/net/reconnect/same_ip.rs | 8 +- serf/test/main/net/reconnect/timeout.rs | 8 +- serf/test/main/net/remove/failed_node.rs | 8 +- .../main/net/remove/failed_node_ourself.rs | 8 +- .../test/main/net/remove/failed_node_prune.rs | 8 +- serf/test/main/net/role.rs | 8 +- serf/test/main/net/set_tags.rs | 8 +- serf/test/main/net/snapshot/snapshoter.rs | 8 +- .../net/snapshot/snapshoter_force_compact.rs | 8 +- .../main/net/snapshot/snapshoter_leave.rs | 8 +- .../net/snapshot/snapshoter_leave_rejoin.rs | 8 +- .../main/net/snapshot/snapshoter_recovery.rs | 8 +- serf/test/main/net/state.rs | 8 +- serf/test/main/net/stats.rs | 8 +- serf/test/main/net/update.rs | 8 +- serf/test/main/net/write_keyring_file.rs | 14 +- 91 files changed, 510 insertions(+), 783 deletions(-) delete mode 100644 serf/test/main/net/event/query_params_encode_filters.rs diff --git a/serf-core/Cargo.toml b/serf-core/Cargo.toml index 596b0b3..a56c69f 100644 --- a/serf-core/Cargo.toml +++ b/serf-core/Cargo.toml @@ -14,8 +14,20 @@ categories.workspace = true [features] default = ["metrics"] metrics = ["memberlist-core/metrics", "dep:metrics"] + encryption = ["memberlist-core/encryption", "base64", "serde"] +crc32 = ["memberlist-core/crc32"] +murmur3 = ["memberlist-core/murmur3"] +xxhash64 = 
["memberlist-core/xxhash64"] +xxhash32 = ["memberlist-core/xxhash32"] +xxhash3 = ["memberlist-core/xxhash3"] + +snappy = ["memberlist-core/snappy"] +zstd = ["memberlist-core/zstd"] +lz4 = ["memberlist-core/lz4"] +brotli = ["memberlist-core/brotli"] + serde = [ "dep:serde", "dep:humantime-serde", diff --git a/serf-core/src/coalesce/member.rs b/serf-core/src/coalesce/member.rs index db325b5..2a29c4d 100644 --- a/serf-core/src/coalesce/member.rs +++ b/serf-core/src/coalesce/member.rs @@ -136,7 +136,7 @@ where // type Transport = UnimplementedTransport< // SmolStr, // SocketAddrResolver, -// Lpe, +// // TokioRuntime, // >; diff --git a/serf-core/src/coalesce/user.rs b/serf-core/src/coalesce/user.rs index 712d505..e5f868a 100644 --- a/serf-core/src/coalesce/user.rs +++ b/serf-core/src/coalesce/user.rs @@ -114,7 +114,7 @@ where // type Transport = UnimplementedTransport< // SmolStr, // SocketAddrResolver, -// Lpe, +// // TokioRuntime, // >; diff --git a/serf-core/src/lib.rs b/serf-core/src/lib.rs index 9737fdb..ddd2038 100644 --- a/serf-core/src/lib.rs +++ b/serf-core/src/lib.rs @@ -50,7 +50,7 @@ pub mod tests { pub use memberlist_core::tests::{AnyError, next_socket_addr_v4, next_socket_addr_v6}; pub use paste; - // pub use super::serf::base::tests::{serf::*, *}; + pub use super::serf::base::tests::{serf::*, *}; /// Add `test` prefix to the predefined unit test fn with a given [`Runtime`](memberlist_core::agnostic_lite::RuntimeLite) #[cfg(any(feature = "test", test))] diff --git a/serf-core/src/serf/base/tests.rs b/serf-core/src/serf/base/tests.rs index 658917a..b66ff27 100644 --- a/serf-core/src/serf/base/tests.rs +++ b/serf-core/src/serf/base/tests.rs @@ -20,7 +20,7 @@ use crate::{ use super::*; -// pub(crate) mod serf; +pub(crate) mod serf; fn test_config() -> Options { let mut opts = Options::new(); diff --git a/serf-core/src/serf/base/tests/serf.rs b/serf-core/src/serf/base/tests/serf.rs index 23769f9..7039b20 100644 --- a/serf-core/src/serf/base/tests/serf.rs +++ 
b/serf-core/src/serf/base/tests/serf.rs @@ -2,9 +2,11 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use memberlist_core::{tests::AnyError, transport::Id}; -use crate::types::{Member, MemberStatus, Tags}; - -use crate::{event::EventProducer, types::MemberState}; +use crate::{ + event::EventProducer, + options::MemberlistOptions, + types::{Member, MemberState, MemberStatus, Tags}, +}; use super::*; @@ -782,7 +784,7 @@ where /// Unit test for serf write keying file #[cfg(feature = "encryption")] pub async fn serf_write_keyring_file( - get_transport_opts: impl FnOnce(memberlist_core::proto::SecretKey) -> T::Options, + get_transport_opts: impl FnOnce(memberlist_core::proto::SecretKey) -> (T::Options, MemberlistOptions), ) where T: Transport, { @@ -800,9 +802,12 @@ pub async fn serf_write_keyring_file( let existing_bytes = general_purpose::STANDARD.decode(EXISTING).unwrap(); let sk = memberlist_core::proto::SecretKey::try_from(existing_bytes.as_slice()).unwrap(); + let (topts, mopts) = get_transport_opts(sk); let serf = Serf::::new( - get_transport_opts(sk), - test_config().with_keyring_file(Some(p.clone())), + topts, + test_config() + .with_keyring_file(Some(p.clone())) + .with_memberlist_options(mopts), ) .await .unwrap(); diff --git a/serf-core/src/serf/base/tests/serf/delegate.rs b/serf-core/src/serf/base/tests/serf/delegate.rs index ab9e0a5..0058591 100644 --- a/serf-core/src/serf/base/tests/serf/delegate.rs +++ b/serf-core/src/serf/base/tests/serf/delegate.rs @@ -78,17 +78,15 @@ where .await; // Verify - assert_eq!(buf[0], MessageType::PushPull.into(), "bad message type"); + assert_eq!(buf[0], u8::from(MessageType::PushPull), "bad message type"); // Attempt a decode - let pp = - crate::types::decode_message(&buf) - .unwrap(); + let pp = crate::types::decode_message::(&buf).unwrap(); let MessageRef::PushPull(pp) = pp else { panic!("bad message") }; - let pp = PushPullMessage::from_ref(pp).unwrap(); + let pp = as Data>::from_ref(pp).unwrap(); // Verify lamport 
clock assert_eq!(pp.ltime(), serfs[0].inner.clock.time(), "bad lamport clock"); @@ -147,7 +145,7 @@ where }), query_ltime: 100.into(), }; - + let buf = crate::types::encode_message_to_bytes(&pp).unwrap(); // Merge in fake state diff --git a/serf-core/src/serf/base/tests/serf/event.rs b/serf-core/src/serf/base/tests/serf/event.rs index d382e50..bb2d214 100644 --- a/serf-core/src/serf/base/tests/serf/event.rs +++ b/serf-core/src/serf/base/tests/serf/event.rs @@ -1,4 +1,6 @@ -use crate::types::{Filter, FilterType}; +use memberlist_core::proto::DataRef; + +use crate::types::{Filter, QueryResponseMessageRef, TagFilter}; use super::*; @@ -17,12 +19,11 @@ where .witness(((event_buffer + 1000) as u64).into()); assert!( !s1 - .handle_user_event( - Either::Right(UserEventMessage::default() + .handle_user_event(Either::Right( + UserEventMessage::default() .with_ltime(1.into()) .with_name("old".into()) - ) - ) + )) .await, "should not rebroadcast" ); @@ -44,19 +45,28 @@ where .with_ltime(1.into()) .with_name("first".into()) .with_payload(Bytes::from_static(b"test")); - assert!(s1.handle_user_event(Either::Right(msg)).await, "should rebroadcast"); + assert!( + s1.handle_user_event(Either::Right(msg)).await, + "should rebroadcast" + ); let msg = UserEventMessage::default() .with_ltime(1.into()) .with_name("first".into()) .with_payload(Bytes::from_static(b"newpayload")); - assert!(s1.handle_user_event(Either::Right(msg)).await, "should rebroadcast"); + assert!( + s1.handle_user_event(Either::Right(msg)).await, + "should rebroadcast" + ); let msg = UserEventMessage::default() .with_ltime(1.into()) .with_name("second".into()) .with_payload(Bytes::from_static(b"other")); - assert!(s1.handle_user_event(Either::Right(msg)).await, "should rebroadcast"); + assert!( + s1.handle_user_event(Either::Right(msg)).await, + "should rebroadcast" + ); test_user_events( event_rx.rx, @@ -548,42 +558,6 @@ where assert_eq!(params.timeout, timeout); } -/// Unit test for query params encode filters 
-pub async fn query_params_encode_filters(transport_opts: T::Options) -where - T: Transport, -{ - let opts = test_config(); - let s = Serf::::new(transport_opts, opts).await.unwrap(); - let mut params = s.default_query_param().await; - params - .filters - .push(Filter::Id(["foo".into(), "bar".into()].into())); - params.filters.push(Filter::Tag { - tag: "role".into(), - expr: "^web".into(), - }); - params.filters.push(Filter::Tag { - tag: "datacenter".into(), - expr: "aws$".into(), - }); - - let filters = params.encode_filters::>().unwrap(); - assert_eq!(filters.len(), 3); - - let (_, node_filt) = - decode_filter(&filters[0]).unwrap(); - assert_eq!(node_filt.ty(), FilterType::Id); - - let (_, tag_filt) = - decode_filter(&filters[1]).unwrap(); - assert_eq!(tag_filt.ty(), FilterType::Tag); - - let (_, tag_filt) = - decode_filter(&filters[2]).unwrap(); - assert_eq!(tag_filt.ty(), FilterType::Tag); -} - /// Unit test for should process functionallity pub async fn should_process(transport_opts: T::Options) where @@ -606,19 +580,21 @@ where .into_iter() .collect(), )); - params.filters.push(Filter::Tag { - tag: "role".into(), - expr: "^web".into(), - }); - params.filters.push(Filter::Tag { - tag: "datacenter".into(), - expr: "aws$".into(), - }); - - let filters = params.encode_filters::>().unwrap(); - assert_eq!(filters.len(), 3); + params.filters.push(Filter::Tag( + TagFilter::new() + .with_tag("role".into()) + .with_expr("^web".try_into().unwrap()), + )); + params.filters.push(Filter::Tag( + TagFilter::new() + .with_tag("datacenter".into()) + .with_expr("aws$".try_into().unwrap()), + )); - assert!(s.should_process_query(&filters)); + assert!( + s.should_process_query(Either::Right(¶ms.filters)) + .unwrap() + ); // Omit node let mut params = s.default_query_param().await; @@ -626,35 +602,41 @@ where .filters .push(Filter::Id(["foo".into(), "bar".into()].into())); - let filters = params.encode_filters::>().unwrap(); - assert!(!s.should_process_query(&filters)); + assert!( 
+ !s.should_process_query(Either::Right(¶ms.filters)) + .unwrap() + ); // Filter on missing tag let mut params = s.default_query_param().await; - params.filters.push(Filter::Tag { - tag: "other".into(), - expr: "cool".into(), - }); + params.filters.push(Filter::Tag( + TagFilter::new() + .with_tag("other".into()) + .with_expr("cool".try_into().unwrap()), + )); - let filters = params.encode_filters::>().unwrap(); - assert!(!s.should_process_query(&filters)); + assert!( + !s.should_process_query(Either::Right(¶ms.filters)) + .unwrap() + ); // Bad tag let mut params = s.default_query_param().await; - params.filters.push(Filter::Tag { - tag: "role".into(), - expr: "db".into(), - }); + params.filters.push(Filter::Tag( + TagFilter::new() + .with_tag("role".into()) + .with_expr("db".try_into().unwrap()), + )); - let filters = params.encode_filters::>().unwrap(); - assert!(!s.should_process_query(&filters)); + assert!( + !s.should_process_query(Either::Right(¶ms.filters)) + .unwrap() + ); } /// Unit tests for the query old message -pub async fn query_old_message( - transport_opts: T::Options, - from: Node, -) where +pub async fn query_old_message(transport_opts: T::Options, from: Node) +where T: Transport, { let opts = test_config(); @@ -667,7 +649,7 @@ pub async fn query_old_message( assert!( !s1 .handle_query( - QueryMessage { + Either::Right(QueryMessage { ltime: 1.into(), id: 0, from, @@ -677,10 +659,11 @@ pub async fn query_old_message( timeout: Default::default(), name: "old".into(), payload: Bytes::new(), - }, + }), None ) - .await.unwrap(), + .await + .unwrap(), "should not rebroadcast" ); @@ -688,10 +671,8 @@ pub async fn query_old_message( } /// Unit tests for the query same clock -pub async fn query_same_clock( - transport_opts: T::Options, - from: Node, -) where +pub async fn query_same_clock(transport_opts: T::Options, from: Node) +where T: Transport, { let opts = test_config(); @@ -713,11 +694,16 @@ pub async fn query_same_clock( }; assert!( - 
s1.handle_query(Either::Right(msg.clone()), None).await, + s1.handle_query(Either::Right(msg.clone()), None) + .await + .unwrap(), "should rebroadcast" ); assert!( - !s1.handle_query(Either::Right(msg.clone()), None).await, + !s1 + .handle_query(Either::Right(msg.clone()), None) + .await + .unwrap(), "should not rebroadcast" ); @@ -734,11 +720,16 @@ pub async fn query_same_clock( }; assert!( - s1.handle_query(Either::Right(msg.clone()), None).await, + s1.handle_query(Either::Right(msg.clone()), None) + .await + .unwrap(), "should rebroadcast" ); assert!( - !s1.handle_query(Either::Right(msg.clone()), None).await, + !s1 + .handle_query(Either::Right(msg.clone()), None) + .await + .unwrap(), "should not rebroadcast" ); @@ -754,11 +745,16 @@ pub async fn query_same_clock( payload: Bytes::from_static(b"other"), }; assert!( - s1.handle_query(Either::Right(msg.clone()), None).await, + s1.handle_query(Either::Right(msg.clone()), None) + .await + .unwrap(), "should rebroadcast" ); assert!( - !s1.handle_query(Either::Right(msg.clone()), None).await, + !s1 + .handle_query(Either::Right(msg.clone()), None) + .await + .unwrap(), "should not rebroadcast" ); @@ -1016,7 +1012,7 @@ where payload: Default::default(), }; let query = QueryResponse::from_query(&mq, 3); - let mut response = QueryResponseMessage { + let response = QueryResponseMessage { ltime: mq.ltime, id: mq.id, from: s.advertise_node(), @@ -1028,12 +1024,21 @@ where qc.responses.insert(mq.ltime, query.clone()); } + let buf = response.encode_to_bytes().unwrap(); + let (_, resp_ref) = as DataRef< + '_, + QueryResponseMessage, + >>::decode(&buf) + .unwrap(); + // Send a few duplicate responses - s.handle_query_response(response.clone()).await; - s.handle_query_response(response.clone()).await; - response.flags |= QueryFlag::ACK; - s.handle_query_response(response.clone()).await; - s.handle_query_response(response.clone()).await; + s.handle_query_response(resp_ref).await.unwrap(); + 
s.handle_query_response(resp_ref).await.unwrap(); + + let mut resp_ref2 = resp_ref; + resp_ref2.flags |= QueryFlag::ACK; + s.handle_query_response(resp_ref2).await.unwrap(); + s.handle_query_response(resp_ref2).await.unwrap(); // Ensure we only get one NodeResponse off the channel let resp_rx = query.response_rx(); diff --git a/serf-core/src/serf/base/tests/serf/join.rs b/serf-core/src/serf/base/tests/serf/join.rs index acbed21..ee229b5 100644 --- a/serf-core/src/serf/base/tests/serf/join.rs +++ b/serf-core/src/serf/base/tests/serf/join.rs @@ -35,10 +35,8 @@ where } /// Unit tests for the join intent old message -pub async fn join_intent_old_message( - transport_opts: T::Options, - addr: T::ResolvedAddress, -) where +pub async fn join_intent_old_message(transport_opts: T::Options, addr: T::ResolvedAddress) +where T: Transport, { let opts = test_config(); @@ -87,10 +85,8 @@ pub async fn join_intent_old_message( } /// Unit tests for the join intent newer -pub async fn join_intent_newer( - transport_opts: T::Options, - addr: T::ResolvedAddress, -) where +pub async fn join_intent_newer(transport_opts: T::Options, addr: T::ResolvedAddress) +where T: Transport, { let opts = test_config(); @@ -138,10 +134,8 @@ pub async fn join_intent_newer( } /// Unit tests for the join intent reset leaving -pub async fn join_intent_reset_leaving( - transport_opts: T::Options, - addr: T::ResolvedAddress, -) where +pub async fn join_intent_reset_leaving(transport_opts: T::Options, addr: T::ResolvedAddress) +where T: Transport, { let opts = test_config(); @@ -270,10 +264,8 @@ where } /// Unit tests for the join pending intent logic -pub async fn join_pending_intent( - transport_opts: T::Options, - addr: T::ResolvedAddress, -) where +pub async fn join_pending_intent(transport_opts: T::Options, addr: T::ResolvedAddress) +where T: Transport, { let opts = test_config(); @@ -310,10 +302,8 @@ pub async fn join_pending_intent( } /// Unit tests for the join pending intent logic -pub async fn 
join_pending_intents( - transport_opts: T::Options, - addr: T::ResolvedAddress, -) where +pub async fn join_pending_intents(transport_opts: T::Options, addr: T::ResolvedAddress) +where T: Transport, { let opts = test_config(); diff --git a/serf-core/src/serf/base/tests/serf/leave.rs b/serf-core/src/serf/base/tests/serf/leave.rs index 3b772df..44e7a7e 100644 --- a/serf-core/src/serf/base/tests/serf/leave.rs +++ b/serf-core/src/serf/base/tests/serf/leave.rs @@ -32,10 +32,8 @@ where } /// Unit tests for the leave intent old message -pub async fn leave_intent_old_message( - transport_opts: T::Options, - addr: T::ResolvedAddress, -) where +pub async fn leave_intent_old_message(transport_opts: T::Options, addr: T::ResolvedAddress) +where T: Transport, { let opts = test_config(); @@ -84,10 +82,8 @@ pub async fn leave_intent_old_message( } /// Unit tests for the leave intent newer -pub async fn leave_intent_newer( - transport_opts: T::Options, - addr: T::ResolvedAddress, -) where +pub async fn leave_intent_newer(transport_opts: T::Options, addr: T::ResolvedAddress) +where T: Transport, { let opts = test_config(); diff --git a/serf-core/src/serf/base/tests/serf/reap.rs b/serf-core/src/serf/base/tests/serf/reap.rs index 669a867..c575971 100644 --- a/serf-core/src/serf/base/tests/serf/reap.rs +++ b/serf-core/src/serf/base/tests/serf/reap.rs @@ -38,10 +38,8 @@ where } /// Unit test for reap handler -pub async fn serf_reap_handler( - opts: T::Options, - addr: T::ResolvedAddress, -) where +pub async fn serf_reap_handler(opts: T::Options, addr: T::ResolvedAddress) +where T: Transport, { let s = Serf::::new( diff --git a/serf-core/src/serf/base/tests/serf/reconnect.rs b/serf-core/src/serf/base/tests/serf/reconnect.rs index 5e3dcf4..3e8a508 100644 --- a/serf-core/src/serf/base/tests/serf/reconnect.rs +++ b/serf-core/src/serf/base/tests/serf/reconnect.rs @@ -1,4 +1,4 @@ -use std::marker::PhantomData; +use core::{marker::PhantomData, net::SocketAddr}; use 
memberlist_core::transport::resolver::socket_addr::SocketAddrResolver; @@ -76,7 +76,7 @@ pub async fn serf_reconnect_same_ip( transport2_id: T::Id, get_transport: impl FnOnce(T::Id, T::ResolvedAddress) -> F + Copy, ) where - T: Transport>, + T: Transport>, T::Options: Clone, R: RuntimeLite, F: core::future::Future, diff --git a/serf-core/src/serf/base/tests/serf/snapshot.rs b/serf-core/src/serf/base/tests/serf/snapshot.rs index d29318f..06ae580 100644 --- a/serf-core/src/serf/base/tests/serf/snapshot.rs +++ b/serf-core/src/serf/base/tests/serf/snapshot.rs @@ -3,10 +3,8 @@ use std::io::Read; use super::*; /// Unit test for the snapshoter. -pub async fn snapshoter( - transport_opts: T::Options, - addr: T::ResolvedAddress, -) where +pub async fn snapshoter(transport_opts: T::Options, addr: T::ResolvedAddress) +where T: Transport, { let dir = tempfile::tempdir().unwrap(); @@ -16,7 +14,7 @@ pub async fn snapshoter( let clock = LamportClock::new(); let (out_tx, out_rx) = async_channel::bounded(64); let (shutdown_tx, shutdown_rx) = async_channel::bounded(1); - let res = open_and_replay_snapshot::<_, _, DefaultDelegate, _>(&p, false).unwrap(); + let res = open_and_replay_snapshot::<_, _, _>(&p, false).unwrap(); let (event_tx, _, handle) = Snapshot::>::from_replay_result( res, SNAPSHOT_SIZE_LIMIT, @@ -162,7 +160,7 @@ pub async fn snapshoter( // Open the snapshoter let (shutdown_tx, shutdown_rx) = async_channel::bounded(1); - let res = open_and_replay_snapshot::<_, _, DefaultDelegate, _>(&p, false).unwrap(); + let res = open_and_replay_snapshot::<_, _, _>(&p, false).unwrap(); assert_eq!(res.last_clock, 100.into()); assert_eq!(res.last_event_clock, 42.into()); @@ -184,7 +182,7 @@ pub async fn snapshoter( assert_eq!(alive_nodes.len(), 1); let n = &alive_nodes[0]; assert_eq!(n.id(), "foo"); - assert_eq!(n.address().clone().into_resolved().unwrap(), addr); + assert_eq!(n.address().clone().unwrap_resolved(), addr); // Close the snapshoter shutdown_tx.close(); @@ -193,7 +191,7 @@ 
pub async fn snapshoter( // Open the snapshoter, make sure nothing dies reading with coordinates // disabled. let (shutdown_tx, shutdown_rx) = async_channel::bounded(1); - let res = open_and_replay_snapshot::<_, _, DefaultDelegate, _>(&p, false).unwrap(); + let res = open_and_replay_snapshot::<_, _, _>(&p, false).unwrap(); let (out_tx, _out_rx) = async_channel::bounded(64); let (_event_tx, _, handle) = Snapshot::>::from_replay_result( @@ -212,10 +210,8 @@ pub async fn snapshoter( } /// Unit test for the snapshoter force compact. -pub async fn snapshoter_force_compact( - transport_opts: T::Options, - addr: T::ResolvedAddress, -) where +pub async fn snapshoter_force_compact(transport_opts: T::Options, addr: T::ResolvedAddress) +where T: Transport, { let dir = tempfile::tempdir().unwrap(); @@ -226,7 +222,7 @@ pub async fn snapshoter_force_compact( let (shutdown_tx, shutdown_rx) = async_channel::bounded(1); // Create a very low limit - let res = open_and_replay_snapshot::<_, _, DefaultDelegate, _>(&p, false).unwrap(); + let res = open_and_replay_snapshot::<_, _, _>(&p, false).unwrap(); let (out_tx, _out_rx) = async_channel::unbounded(); let (event_tx, _, handle) = Snapshot::>::from_replay_result( res, @@ -274,17 +270,15 @@ pub async fn snapshoter_force_compact( handle.wait().await; // Open the snapshoter - let res = open_and_replay_snapshot::<_, _, DefaultDelegate, _>(&p, false).unwrap(); + let res = open_and_replay_snapshot::(&p, false).unwrap(); assert_eq!(res.last_event_clock, 1023.into()); assert_eq!(res.last_query_clock, 1023.into()); } /// Unit test for the snapshoter leave -pub async fn snapshoter_leave( - transport_opts: T::Options, - addr: T::ResolvedAddress, -) where +pub async fn snapshoter_leave(transport_opts: T::Options, addr: T::ResolvedAddress) +where T: Transport, { let dir = tempfile::tempdir().unwrap(); @@ -293,7 +287,7 @@ pub async fn snapshoter_leave( let clock = LamportClock::new(); let (shutdown_tx, shutdown_rx) = async_channel::bounded(1); - let 
res = open_and_replay_snapshot::<_, _, DefaultDelegate, _>(&p, false).unwrap(); + let res = open_and_replay_snapshot::<_, _, _>(&p, false).unwrap(); let (out_tx, _out_rx) = async_channel::unbounded(); let (event_tx, _, handle) = Snapshot::>::from_replay_result( res, @@ -359,7 +353,7 @@ pub async fn snapshoter_leave( // Open the snapshoter let (shutdown_tx, shutdown_rx) = async_channel::bounded(1); - let res = open_and_replay_snapshot::<_, _, DefaultDelegate, _>(&p, false).unwrap(); + let res = open_and_replay_snapshot::<_, _, _>(&p, false).unwrap(); assert!(res.last_clock == 0.into(), "last_clock: {}", res.last_clock); assert!( res.last_event_clock == 0.into(), @@ -392,10 +386,8 @@ pub async fn snapshoter_leave( } /// Unit test for the snapshoter leave rejoin -pub async fn snapshoter_leave_rejoin( - transport_opts: T::Options, - addr: T::ResolvedAddress, -) where +pub async fn snapshoter_leave_rejoin(transport_opts: T::Options, addr: T::ResolvedAddress) +where T: Transport, { let dir = tempfile::tempdir().unwrap(); @@ -404,7 +396,7 @@ pub async fn snapshoter_leave_rejoin( let clock = LamportClock::new(); let (shutdown_tx, shutdown_rx) = async_channel::bounded(1); - let res = open_and_replay_snapshot::<_, _, DefaultDelegate, _>(&p, true).unwrap(); + let res = open_and_replay_snapshot::<_, _, _>(&p, true).unwrap(); let (out_tx, _out_rx) = async_channel::unbounded(); let (event_tx, _, handle) = Snapshot::>::from_replay_result( res, @@ -474,7 +466,7 @@ pub async fn snapshoter_leave_rejoin( // Open the snapshoter let (shutdown_tx, shutdown_rx) = async_channel::bounded(1); - let res = open_and_replay_snapshot::<_, _, DefaultDelegate, _>(&p, true).unwrap(); + let res = open_and_replay_snapshot::<_, _, _>(&p, true).unwrap(); assert!(res.last_clock == 100.into()); assert!(res.last_event_clock == 42.into()); assert!(res.last_query_clock == 50.into()); @@ -615,18 +607,12 @@ pub async fn serf_snapshot_recovery( async fn test_snapshoter_slow_disk_not_blocking_event_tx() { use 
memberlist_core::{ agnostic_lite::tokio::TokioRuntime, - transport::{resolver::socket_addr::SocketAddrResolver, tests::UnimplementedTransport}, + transport::{resolver::socket_addr::SocketAddrResolver, unimplemented::UnimplementedTransport}, }; - use std::net::SocketAddr; crate::tests::initialize_tests_tracing(); - type Transport = UnimplementedTransport< - SmolStr, - SocketAddrResolver, - Lpe, - TokioRuntime, - >; + type Transport = UnimplementedTransport, TokioRuntime>; type Delegate = DefaultDelegate; @@ -638,7 +624,7 @@ async fn test_snapshoter_slow_disk_not_blocking_event_tx() { let clock = LamportClock::new(); let (shutdown_tx, shutdown_rx) = async_channel::bounded(1); let (out_tx, out_rx) = async_channel::bounded(1024); - let res = open_and_replay_snapshot::<_, _, Delegate, _>(&p, true).unwrap(); + let res = open_and_replay_snapshot::<_, _, _>(&p, true).unwrap(); let (event_tx, _, handle) = Snapshot::::from_replay_result( res, SNAPSHOT_SIZE_LIMIT, @@ -703,7 +689,7 @@ async fn test_snapshoter_slow_disk_not_blocking_event_tx() { // 115ms on my machine so this should give plenty of headroom for slower CI // environments while still being low enough that actual disk IO would // reliably blow it. 
- let deadline = TokioRuntime::sleep_until(std::time::Instant::now() + Duration::from_millis(500)); + let deadline = TokioRuntime::sleep_until(TokioRuntime::now() + Duration::from_millis(500)); futures::pin_mut!(deadline); let mut num_recvd = 0; let start = Epoch::now(); @@ -732,16 +718,10 @@ async fn test_snapshoter_slow_disk_not_blocking_event_tx() { async fn test_snapshoter_slow_disk_not_blocking_memberlist() { use memberlist_core::{ agnostic_lite::tokio::TokioRuntime, - transport::{Lpe, resolver::socket_addr::SocketAddrResolver, tests::UnimplementedTransport}, + transport::{resolver::socket_addr::SocketAddrResolver, unimplemented::UnimplementedTransport}, }; - use std::net::SocketAddr; - type Transport = UnimplementedTransport< - SmolStr, - SocketAddrResolver, - Lpe, - TokioRuntime, - >; + type Transport = UnimplementedTransport, TokioRuntime>; type Delegate = DefaultDelegate; @@ -753,7 +733,7 @@ async fn test_snapshoter_slow_disk_not_blocking_memberlist() { let clock = LamportClock::new(); let (shutdown_tx, shutdown_rx) = async_channel::bounded(1); let (out_tx, _out_rx) = async_channel::bounded(1); - let res = open_and_replay_snapshot::<_, _, Delegate, _>(&p, true).unwrap(); + let res = open_and_replay_snapshot::<_, _, _>(&p, true).unwrap(); let (event_tx, _, handle) = Snapshot::::from_replay_result( res, SNAPSHOT_SIZE_LIMIT, diff --git a/serf-core/src/serf/delegate.rs b/serf-core/src/serf/delegate.rs index f750da1..6e8cae4 100644 --- a/serf-core/src/serf/delegate.rs +++ b/serf-core/src/serf/delegate.rs @@ -6,7 +6,7 @@ use crate::{ event::QueryMessageExt, types::{ DelegateVersion, JoinMessage, LamportTime, LeaveMessage, Member, MemberStatus, MessageRef, - MessageType, ProtocolVersion, PushPullMessageBorrow, UserEventMessage, + ProtocolVersion, PushPullMessageBorrow, UserEventMessage, }, }; @@ -31,6 +31,9 @@ use memberlist_core::{ transport::Transport, }; +#[cfg(any(test, feature = "test"))] +use crate::types::MessageType; + // PingVersion is an internal 
version for the ping message, above the normal // versioning we get from the protocol version. This enables small updates // to the ping message without a full protocol bump. diff --git a/serf-core/src/types.rs b/serf-core/src/types.rs index 6f0be3e..d79084d 100644 --- a/serf-core/src/types.rs +++ b/serf-core/src/types.rs @@ -5,6 +5,46 @@ pub use memberlist_core::proto::{ ParseHostAddrError, ParseNodeIdError, ProtocolVersion as MemberlistProtocolVersion, }; +#[cfg(feature = "encryption")] +#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] +pub use memberlist_core::proto::encryption::*; + +#[cfg(any( + feature = "crc32", + feature = "xxhash64", + feature = "xxhash32", + feature = "xxhash3", + feature = "murmur3", +))] +#[cfg_attr( + docsrs, + doc(cfg(any( + feature = "crc32", + feature = "xxhash64", + feature = "xxhash32", + feature = "xxhash3", + feature = "murmur3" + ))) +)] +pub use memberlist_core::proto::checksum::*; + +#[cfg(any( + feature = "zstd", + feature = "lz4", + feature = "snappy", + feature = "brotli", +))] +#[cfg_attr( + docsrs, + doc(cfg(any( + feature = "zstd", + feature = "lz4", + feature = "snappy", + feature = "brotli" + ))) +)] +pub use memberlist_core::proto::compression::*; + #[cfg(feature = "arbitrary")] mod arbitrary_impl; diff --git a/serf-core/src/types/filter.rs b/serf-core/src/types/filter.rs index f4c0ed1..1fb00ce 100644 --- a/serf-core/src/types/filter.rs +++ b/serf-core/src/types/filter.rs @@ -57,8 +57,19 @@ impl From for u8 { /// Used with a queryFilter to specify the type of /// filter we are sending -#[derive(Debug, Clone, PartialEq, Eq, derive_more::IsVariant)] +#[derive( + Debug, + Clone, + PartialEq, + Eq, + derive_more::IsVariant, + derive_more::From, + derive_more::Unwrap, + derive_more::TryUnwrap, +)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[unwrap(ref, ref_mut)] +#[try_unwrap(ref, ref_mut)] #[non_exhaustive] pub enum Filter { /// Filter by node ids diff --git 
a/serf-core/src/types/filter/tag_filter.rs b/serf-core/src/types/filter/tag_filter.rs index c2ecc72..07cef11 100644 --- a/serf-core/src/types/filter/tag_filter.rs +++ b/serf-core/src/types/filter/tag_filter.rs @@ -118,6 +118,13 @@ impl TagFilter { expr: None, } } + + /// Set the expression for the tag filter + #[inline] + pub fn with_expr(mut self, expr: Regex) -> Self { + self.expr = Some(expr); + self + } } impl PartialEq for TagFilter { diff --git a/serf-core/src/types/query/response.rs b/serf-core/src/types/query/response.rs index 633233e..0c9cf3a 100644 --- a/serf-core/src/types/query/response.rs +++ b/serf-core/src/types/query/response.rs @@ -66,16 +66,14 @@ impl QueryResponseMessage { pub fn ack(&self) -> bool { self.flags.contains(QueryFlag::ACK) } - - /// Checks if the no broadcast flag is set - #[inline] - pub fn no_broadcast(&self) -> bool { - self.flags.contains(QueryFlag::NO_BROADCAST) - } } /// The reference type to a query response message -#[viewit::viewit(vis_all = "", getters(vis_all = "pub", style = "ref"), setters(skip))] +#[viewit::viewit( + vis_all = "pub(crate)", + getters(vis_all = "pub", style = "ref"), + setters(skip) +)] #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub struct QueryResponseMessageRef<'a, I, A> { /// Event lamport time diff --git a/serf/src/async_std.rs b/serf/src/async_std.rs index 0f98dea..a030d4b 100644 --- a/serf/src/async_std.rs +++ b/serf/src/async_std.rs @@ -1,23 +1,16 @@ pub use memberlist::agnostic::async_std::AsyncStdRuntime; /// [`Serf`](super::Serf) type alias for using [`NetTransport`](memberlist::net::NetTransport) and [`Tcp`](memberlist::net::stream_layer::tcp::Tcp) stream layer with `async-std` runtime. 
-#[cfg(all( - any(feature = "tcp", feature = "tls", feature = "native-tls"), - not(target_family = "wasm") -))] +#[cfg(all(any(feature = "tcp", feature = "tls",), not(target_family = "wasm")))] #[cfg_attr( docsrs, - doc(cfg(all( - any(feature = "tcp", feature = "tls", feature = "native-tls"), - not(target_family = "wasm") - ))) + doc(cfg(all(any(feature = "tcp", feature = "tls",), not(target_family = "wasm")))) )] -pub type AsyncStdTcpSerf = serf_core::Serf< +pub type AsyncStdTcpSerf = serf_core::Serf< memberlist::net::NetTransport< I, A, memberlist::net::stream_layer::tcp::Tcp, - W, AsyncStdRuntime, >, D, @@ -26,29 +19,11 @@ pub type AsyncStdTcpSerf = serf_core::Serf< /// [`Serf`](super::Serf) type alias for using [`NetTransport`](memberlist::net::NetTransport) and [`Tls`](memberlist::net::stream_layer::tls::Tls) stream layer with `async-std` runtime. #[cfg(all(feature = "tls", not(target_family = "wasm")))] #[cfg_attr(docsrs, doc(cfg(all(feature = "tls", not(target_family = "wasm")))))] -pub type AsyncStdTlsSerf = serf_core::Serf< +pub type AsyncStdTlsSerf = serf_core::Serf< memberlist::net::NetTransport< I, A, memberlist::net::stream_layer::tls::Tls, - W, - AsyncStdRuntime, - >, - D, ->; - -/// [`Serf`](super::Serf) type alias for using [`NetTransport`](memberlist::net::NetTransport) and [`NativeTls`](memberlist::net::stream_layer::native_tls::NativeTls) stream layer with `async-std` runtime. 
-#[cfg(all(feature = "native-tls", not(target_family = "wasm")))] -#[cfg_attr( - docsrs, - doc(cfg(all(feature = "native-tls", not(target_family = "wasm")))) -)] -pub type AsyncStdNativeTlsSerf = serf_core::Serf< - memberlist::net::NetTransport< - I, - A, - memberlist::net::stream_layer::native_tls::NativeTls, - W, AsyncStdRuntime, >, D, @@ -57,26 +32,11 @@ pub type AsyncStdNativeTlsSerf = serf_core::Serf< /// [`Serf`](super::Serf) type alias for using [`QuicTransport`](memberlist::quic::QuicTransport) and [`Quinn`](memberlist::quic::stream_layer::quinn::Quinn) stream layer with `async-std` runtime. #[cfg(all(feature = "quinn", not(target_family = "wasm")))] #[cfg_attr(docsrs, doc(cfg(all(feature = "quinn", not(target_family = "wasm")))))] -pub type AsyncStdQuicSerf = serf_core::Serf< +pub type AsyncStdQuicSerf = serf_core::Serf< memberlist::quic::QuicTransport< I, A, memberlist::quic::stream_layer::quinn::Quinn, - W, - AsyncStdRuntime, - >, - D, ->; - -/// [`Serf`](super::Serf) type alias for using [`QuicTransport`](memberlist::quic::QuicTransport) and [`S2n`](memberlist::quic::stream_layer::s2n::S2n) stream layer with `async-std` runtime. 
-#[cfg(all(feature = "s2n", not(target_family = "wasm")))] -#[cfg_attr(docsrs, doc(cfg(all(feature = "s2n", not(target_family = "wasm")))))] -pub type AsyncStdS2nSerf = serf_core::Serf< - memberlist::quic::QuicTransport< - I, - A, - memberlist::quic::stream_layer::s2n::S2n, - W, AsyncStdRuntime, >, D, diff --git a/serf/src/lib.rs b/serf/src/lib.rs index 352ea93..8a1026e 100644 --- a/serf/src/lib.rs +++ b/serf/src/lib.rs @@ -1,32 +1,32 @@ -// #![doc = include_str!("../../README.md")] -// #![doc(html_logo_url = "https://raw.githubusercontent.com/al8n/serf/main/art/logo_72x72.png")] -// #![forbid(unsafe_code)] -// #![deny(warnings, missing_docs)] -// #![allow(clippy::type_complexity)] -// #![cfg_attr(docsrs, feature(doc_cfg))] -// #![cfg_attr(docsrs, allow(unused_attributes))] +#![doc = include_str!("../../README.md")] +#![doc(html_logo_url = "https://raw.githubusercontent.com/al8n/serf/main/art/logo_72x72.png")] +#![forbid(unsafe_code)] +#![deny(warnings, missing_docs)] +#![allow(clippy::type_complexity)] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr(docsrs, allow(unused_attributes))] -// pub use serf_core::*; +pub use serf_core::*; -// pub use memberlist::{agnostic, transport}; +pub use memberlist::{agnostic, transport}; -// #[cfg(feature = "net")] -// pub use memberlist::net; +#[cfg(feature = "net")] +pub use memberlist::net; -// #[cfg(feature = "quic")] -// pub use memberlist::quic; +#[cfg(feature = "quic")] +pub use memberlist::quic; -// /// [`Serf`](serf_core::Serf) for `tokio` runtime. -// #[cfg(feature = "tokio")] -// #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))] -// pub mod tokio; +/// [`Serf`](serf_core::Serf) for `tokio` runtime. +#[cfg(feature = "tokio")] +#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))] +pub mod tokio; -// /// [`Serf`](serf_core::Serf) for `async-std` runtime. -// #[cfg(feature = "async-std")] -// #[cfg_attr(docsrs, doc(cfg(feature = "async-std")))] -// pub mod async_std; +/// [`Serf`](serf_core::Serf) for `async-std` runtime. 
+#[cfg(feature = "async-std")] +#[cfg_attr(docsrs, doc(cfg(feature = "async-std")))] +pub mod async_std; -// /// [`Serf`](serf_core::Serf) for `smol` runtime. -// #[cfg(feature = "smol")] -// #[cfg_attr(docsrs, doc(cfg(feature = "smol")))] -// pub mod smol; +/// [`Serf`](serf_core::Serf) for `smol` runtime. +#[cfg(feature = "smol")] +#[cfg_attr(docsrs, doc(cfg(feature = "smol")))] +pub mod smol; diff --git a/serf/src/smol.rs b/serf/src/smol.rs index 7bad2e0..0c56d9b 100644 --- a/serf/src/smol.rs +++ b/serf/src/smol.rs @@ -1,23 +1,16 @@ pub use memberlist::agnostic::smol::SmolRuntime; /// [`Serf`](super::Serf) type alias for using [`NetTransport`](memberlist::net::NetTransport) and [`Tcp`](memberlist::net::stream_layer::tcp::Tcp) stream layer with `smol` runtime. -#[cfg(all( - any(feature = "tcp", feature = "tls", feature = "native-tls"), - not(target_family = "wasm") -))] +#[cfg(all(any(feature = "tcp", feature = "tls",), not(target_family = "wasm")))] #[cfg_attr( docsrs, - doc(cfg(all( - any(feature = "tcp", feature = "tls", feature = "native-tls"), - not(target_family = "wasm") - ))) + doc(cfg(all(any(feature = "tcp", feature = "tls",), not(target_family = "wasm")))) )] -pub type SmolTcpSerf = serf_core::Serf< +pub type SmolTcpSerf = serf_core::Serf< memberlist::net::NetTransport< I, A, memberlist::net::stream_layer::tcp::Tcp, - W, SmolRuntime, >, D, @@ -26,29 +19,11 @@ pub type SmolTcpSerf = serf_core::Serf< /// [`Serf`](super::Serf) type alias for using [`NetTransport`](memberlist::net::NetTransport) and [`Tls`](memberlist::net::stream_layer::tls::Tls) stream layer with `smol` runtime. 
#[cfg(all(feature = "tls", not(target_family = "wasm")))] #[cfg_attr(docsrs, doc(cfg(all(feature = "tls", not(target_family = "wasm")))))] -pub type SmolTlsSerf = serf_core::Serf< +pub type SmolTlsSerf = serf_core::Serf< memberlist::net::NetTransport< I, A, memberlist::net::stream_layer::tls::Tls, - W, - SmolRuntime, - >, - D, ->; - -/// [`Serf`](super::Serf) type alias for using [`NetTransport`](memberlist::net::NetTransport) and [`NativeTls`](memberlist::net::stream_layer::native_tls::NativeTls) stream layer with `smol` runtime. -#[cfg(all(feature = "native-tls", not(target_family = "wasm")))] -#[cfg_attr( - docsrs, - doc(cfg(all(feature = "native-tls", not(target_family = "wasm")))) -)] -pub type SmolNativeTlsSerf = serf_core::Serf< - memberlist::net::NetTransport< - I, - A, - memberlist::net::stream_layer::native_tls::NativeTls, - W, SmolRuntime, >, D, @@ -57,26 +32,11 @@ pub type SmolNativeTlsSerf = serf_core::Serf< /// [`Serf`](super::Serf) type alias for using [`QuicTransport`](memberlist::quic::QuicTransport) and [`Quinn`](memberlist::quic::stream_layer::quinn::Quinn) stream layer with `smol` runtime. #[cfg(all(feature = "quinn", not(target_family = "wasm")))] #[cfg_attr(docsrs, doc(cfg(all(feature = "quinn", not(target_family = "wasm")))))] -pub type SmolQuicSerf = serf_core::Serf< +pub type SmolQuicSerf = serf_core::Serf< memberlist::quic::QuicTransport< I, A, memberlist::quic::stream_layer::quinn::Quinn, - W, - SmolRuntime, - >, - D, ->; - -/// [`Serf`](super::Serf) type alias for using [`QuicTransport`](memberlist::quic::QuicTransport) and [`S2n`](memberlist::quic::stream_layer::s2n::S2n) stream layer with `smol` runtime. 
-#[cfg(all(feature = "s2n", not(target_family = "wasm")))] -#[cfg_attr(docsrs, doc(cfg(all(feature = "s2n", not(target_family = "wasm")))))] -pub type SmolS2nSerf = serf_core::Serf< - memberlist::quic::QuicTransport< - I, - A, - memberlist::quic::stream_layer::s2n::S2n, - W, SmolRuntime, >, D, diff --git a/serf/src/tokio.rs b/serf/src/tokio.rs index 347afb5..2e4f8e6 100644 --- a/serf/src/tokio.rs +++ b/serf/src/tokio.rs @@ -1,23 +1,16 @@ pub use memberlist::agnostic::tokio::TokioRuntime; /// [`Serf`](super::Serf) type alias for using [`NetTransport`](memberlist::net::NetTransport) and [`Tcp`](memberlist::net::stream_layer::tcp::Tcp) stream layer with `tokio` runtime. -#[cfg(all( - any(feature = "tcp", feature = "tls", feature = "native-tls"), - not(target_family = "wasm") -))] +#[cfg(all(any(feature = "tcp", feature = "tls",), not(target_family = "wasm")))] #[cfg_attr( docsrs, - doc(cfg(all( - any(feature = "tcp", feature = "tls", feature = "native-tls"), - not(target_family = "wasm") - ))) + doc(cfg(all(any(feature = "tcp", feature = "tls",), not(target_family = "wasm")))) )] -pub type TokioTcpSerf = serf_core::Serf< +pub type TokioTcpSerf = serf_core::Serf< memberlist::net::NetTransport< I, A, memberlist::net::stream_layer::tcp::Tcp, - W, TokioRuntime, >, D, @@ -26,29 +19,11 @@ pub type TokioTcpSerf = serf_core::Serf< /// [`Serf`](super::Serf) type alias for using [`NetTransport`](memberlist::net::NetTransport) and [`Tls`](memberlist::net::stream_layer::tls::Tls) stream layer with `tokio` runtime. 
#[cfg(all(feature = "tls", not(target_family = "wasm")))] #[cfg_attr(docsrs, doc(cfg(all(feature = "tls", not(target_family = "wasm")))))] -pub type TokioTlsSerf = serf_core::Serf< +pub type TokioTlsSerf = serf_core::Serf< memberlist::net::NetTransport< I, A, memberlist::net::stream_layer::tls::Tls, - W, - TokioRuntime, - >, - D, ->; - -/// [`Serf`](super::Serf) type alias for using [`NetTransport`](memberlist::net::NetTransport) and [`NativeTls`](memberlist::net::stream_layer::native_tls::NativeTls) stream layer with `tokio` runtime. -#[cfg(all(feature = "native-tls", not(target_family = "wasm")))] -#[cfg_attr( - docsrs, - doc(cfg(all(feature = "native-tls", not(target_family = "wasm")))) -)] -pub type TokioNativeTlsSerf = serf_core::Serf< - memberlist::net::NetTransport< - I, - A, - memberlist::net::stream_layer::native_tls::NativeTls, - W, TokioRuntime, >, D, @@ -57,26 +32,11 @@ pub type TokioNativeTlsSerf = serf_core::Serf< /// [`Serf`](super::Serf) type alias for using [`QuicTransport`](memberlist::quic::QuicTransport) and [`Quinn`](memberlist::quic::stream_layer::quinn::Quinn) stream layer with `tokio` runtime. #[cfg(all(feature = "quinn", not(target_family = "wasm")))] #[cfg_attr(docsrs, doc(cfg(all(feature = "quinn", not(target_family = "wasm")))))] -pub type TokioQuicSerf = serf_core::Serf< +pub type TokioQuicSerf = serf_core::Serf< memberlist::quic::QuicTransport< I, A, memberlist::quic::stream_layer::quinn::Quinn, - W, - TokioRuntime, - >, - D, ->; - -/// [`Serf`](super::Serf) type alias for using [`QuicTransport`](memberlist::quic::QuicTransport) and [`S2n`](memberlist::quic::stream_layer::s2n::S2n) stream layer with `tokio` runtime. 
-#[cfg(all(feature = "s2n", not(target_family = "wasm")))] -#[cfg_attr(docsrs, doc(cfg(all(feature = "s2n", not(target_family = "wasm")))))] -pub type TokioS2nSerf = serf_core::Serf< - memberlist::quic::QuicTransport< - I, - A, - memberlist::quic::stream_layer::s2n::S2n, - W, TokioRuntime, >, D, diff --git a/serf/test/main.rs b/serf/test/main.rs index bde3dd0..c8ed71e 100644 --- a/serf/test/main.rs +++ b/serf/test/main.rs @@ -1,28 +1,28 @@ -// use core::future::Future; -// use serf_core::tests::run as run_unit_test; +use core::future::Future; +use serf_core::tests::run as run_unit_test; -// #[cfg(feature = "net")] -// #[path = "./main/net.rs"] -// mod net; +#[cfg(feature = "net")] +#[path = "./main/net.rs"] +mod net; -// #[cfg(feature = "tokio")] -// fn tokio_run(fut: impl Future) { -// let runtime = ::tokio::runtime::Builder::new_multi_thread() -// .worker_threads(32) -// .enable_all() -// .build() -// .unwrap(); -// run_unit_test(|fut| runtime.block_on(fut), fut) -// } +#[cfg(feature = "tokio")] +fn tokio_run(fut: impl Future) { + let runtime = ::tokio::runtime::Builder::new_multi_thread() + .worker_threads(32) + .enable_all() + .build() + .unwrap(); + run_unit_test(|fut| runtime.block_on(fut), fut) +} -// #[cfg(feature = "smol")] -// fn smol_run(fut: impl Future) { -// use serf::agnostic::{RuntimeLite, smol::SmolRuntime}; -// run_unit_test(SmolRuntime::block_on, fut); -// } +#[cfg(feature = "smol")] +fn smol_run(fut: impl Future) { + use serf::agnostic::{RuntimeLite, smol::SmolRuntime}; + run_unit_test(SmolRuntime::block_on, fut); +} -// #[cfg(feature = "async-std")] -// fn async_std_run(fut: impl Future) { -// use serf::agnostic::{RuntimeLite, async_std::AsyncStdRuntime}; -// run_unit_test(AsyncStdRuntime::block_on, fut); -// } +#[cfg(feature = "async-std")] +fn async_std_run(fut: impl Future) { + use serf::agnostic::{RuntimeLite, async_std::AsyncStdRuntime}; + run_unit_test(AsyncStdRuntime::block_on, fut); +} diff --git a/serf/test/main/net/coordinates.rs 
b/serf/test/main/net/coordinates.rs index 646a19f..a690850 100644 --- a/serf/test/main/net/coordinates.rs +++ b/serf/test/main/net/coordinates.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{serf_coordinates, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -35,7 +33,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2, opts3)); @@ -60,7 +58,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2, opts3)); diff --git a/serf/test/main/net/delegate/local_state.rs b/serf/test/main/net/delegate/local_state.rs index 5bce1da..a0455cc 100644 --- a/serf/test/main/net/delegate/local_state.rs +++ b/serf/test/main/net/delegate/local_state.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{delegate::delegate_local_state, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -52,7 +50,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); diff --git a/serf/test/main/net/delegate/nodemeta.rs b/serf/test/main/net/delegate/nodemeta.rs index 6547387..392e25a 100644 --- a/serf/test/main/net/delegate/nodemeta.rs +++ b/serf/test/main/net/delegate/nodemeta.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{delegate::{delegate_nodemeta, delegate_nodemeta_panic}, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -26,7 +24,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -42,7 +40,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -59,7 +57,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -76,7 +74,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); diff --git a/serf/test/main/net/delegate/ping_delegate.rs b/serf/test/main/net/delegate/ping_delegate.rs index 7e1e2d8..bb6cce0 100644 --- a/serf/test/main/net/delegate/ping_delegate.rs +++ b/serf/test/main/net/delegate/ping_delegate.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! 
test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{delegate::{serf_ping_delegate_versioning, serf_ping_delegate_rogue_coordinate}, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -52,7 +50,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -73,7 +71,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -94,7 +92,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); diff --git a/serf/test/main/net/delegate/remote_state.rs b/serf/test/main/net/delegate/remote_state.rs index a9fd77c..8537862 100644 --- a/serf/test/main/net/delegate/remote_state.rs +++ b/serf/test/main/net/delegate/remote_state.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{delegate::delegate_merge_remote_state, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -44,7 +42,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); diff --git a/serf/test/main/net/event.rs b/serf/test/main/net/event.rs index 2a0e0fb..a832f84 100644 --- a/serf/test/main/net/event.rs +++ b/serf/test/main/net/event.rs @@ -28,9 +28,6 @@ mod query_filter; #[path = "./event/query_old_message.rs"] mod query_old_message; -#[path = "./event/query_params_encode_filters.rs"] -mod query_params_encode_filters; - #[path = "./event/query.rs"] mod query; diff --git a/serf/test/main/net/event/default_query.rs b/serf/test/main/net/event/default_query.rs index bf0cd9b..de13aac 100644 --- a/serf/test/main/net/event/default_query.rs +++ b/serf/test/main/net/event/default_query.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{event::default_query, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); diff --git a/serf/test/main/net/event/event_user.rs b/serf/test/main/net/event/event_user.rs index fd6f32d..098faa1 100644 --- a/serf/test/main/net/event/event_user.rs +++ b/serf/test/main/net/event/event_user.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! 
test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{event::serf_event_user, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -52,7 +50,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); diff --git a/serf/test/main/net/event/event_user_size_limit.rs b/serf/test/main/net/event/event_user_size_limit.rs index ec525b8..ffb3bb3 100644 --- a/serf/test/main/net/event/event_user_size_limit.rs +++ b/serf/test/main/net/event/event_user_size_limit.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{event::serf_event_user_size_limit, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); diff --git a/serf/test/main/net/event/events_failed.rs b/serf/test/main/net/event/events_failed.rs index acc40e0..079d44e 100644 --- a/serf/test/main/net/event/events_failed.rs +++ b/serf/test/main/net/event/events_failed.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! 
{ mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{event::serf_events_failed, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -52,7 +50,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); diff --git a/serf/test/main/net/event/events_join.rs b/serf/test/main/net/event/events_join.rs index 1b126b4..c9a20af 100644 --- a/serf/test/main/net/event/events_join.rs +++ b/serf/test/main/net/event/events_join.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{event::serf_events_join, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -52,7 +50,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); diff --git a/serf/test/main/net/event/events_leave.rs b/serf/test/main/net/event/events_leave.rs index b537609..56357a5 100644 --- a/serf/test/main/net/event/events_leave.rs +++ b/serf/test/main/net/event/events_leave.rs @@ -2,8 +2,6 @@ macro_rules! 
test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{event::serf_events_leave, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -52,7 +50,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); diff --git a/serf/test/main/net/event/events_leave_avoid_infinite_rebroadcast.rs b/serf/test/main/net/event/events_leave_avoid_infinite_rebroadcast.rs index 1089030..dc7ef3b 100644 --- a/serf/test/main/net/event/events_leave_avoid_infinite_rebroadcast.rs +++ b/serf/test/main/net/event/events_leave_avoid_infinite_rebroadcast.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{event::serf_events_leave_avoid_infinite_rebroadcast, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -39,7 +37,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, _, @@ -73,7 +71,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, _, diff --git a/serf/test/main/net/event/query.rs b/serf/test/main/net/event/query.rs index fbf161f..1112993 100644 --- a/serf/test/main/net/event/query.rs +++ b/serf/test/main/net/event/query.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{event::serf_query, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -52,7 +50,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); diff --git a/serf/test/main/net/event/query_deduplicate.rs b/serf/test/main/net/event/query_deduplicate.rs index 7c125da..0f0b23e 100644 --- a/serf/test/main/net/event/query_deduplicate.rs +++ b/serf/test/main/net/event/query_deduplicate.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{event::serf_query_deduplicate, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -44,7 +42,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); diff --git a/serf/test/main/net/event/query_filter.rs b/serf/test/main/net/event/query_filter.rs index 4aa04c3..7055f4b 100644 --- a/serf/test/main/net/event/query_filter.rs +++ b/serf/test/main/net/event/query_filter.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{event::serf_query_filter, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -35,7 +33,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2, opts3)); @@ -60,7 +58,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2, opts3)); diff --git a/serf/test/main/net/event/query_old_message.rs b/serf/test/main/net/event/query_old_message.rs index 1528285..69e3bbf 100644 --- a/serf/test/main/net/event/query_old_message.rs +++ b/serf/test/main/net/event/query_old_message.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::{Lpe, Node}, + transport::Node, }; use serf_core::tests::{event::query_old_message, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, Node::new("fake1".into(), next_socket_addr_v4(0)))); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, Node::new("fake1".into(), next_socket_addr_v4(0)))); diff --git a/serf/test/main/net/event/query_params_encode_filters.rs b/serf/test/main/net/event/query_params_encode_filters.rs deleted file mode 100644 index aa695c0..0000000 --- a/serf/test/main/net/event/query_params_encode_filters.rs +++ /dev/null @@ -1,64 +0,0 @@ -macro_rules! test_mod { - ($rt:ident) => { - paste::paste! { - mod [< $rt:snake >] { - use std::net::SocketAddr; - - use crate::[< $rt:snake _run >]; - use serf::{ - net::{ - resolver::socket_addr::SocketAddrResolver, stream_layer::tcp::Tcp, NetTransport, - NetTransportOptions, - }, - [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, - }; - use serf_core::tests::{event::query_params_encode_filters, next_socket_addr_v4, next_socket_addr_v6}; - use smol_str::SmolStr; - - #[test] - fn test_query_params_encode_filters_v4() { - let name = "query_params_encode_filters_v4"; - let mut opts = NetTransportOptions::new(SmolStr::new(name)); - opts.add_bind_address(next_socket_addr_v4(0)); - - [< $rt:snake _run >](query_params_encode_filters::< - NetTransport< - SmolStr, - SocketAddrResolver<[< $rt:camel Runtime >]>, - Tcp<[< $rt:camel Runtime >]>, - Lpe, - [< $rt:camel Runtime >], - >, - >(opts)); - } - - #[test] - fn test_query_params_encode_filters_v6() { - let name = "query_params_encode_filters_v6"; - let mut opts = NetTransportOptions::new(SmolStr::new(name)); - opts.add_bind_address(next_socket_addr_v6()); - - [< $rt:snake _run >](query_params_encode_filters::< - NetTransport< - SmolStr, - SocketAddrResolver<[< $rt:camel Runtime >]>, - Tcp<[< $rt:camel Runtime >]>, - Lpe, - [< $rt:camel 
Runtime >], - >, - >(opts)); - } - } - } - }; -} - -#[cfg(feature = "tokio")] -test_mod!(tokio); - -#[cfg(feature = "async-std")] -test_mod!(async_std); - -#[cfg(feature = "smol")] -test_mod!(smol); diff --git a/serf/test/main/net/event/query_same_clock.rs b/serf/test/main/net/event/query_same_clock.rs index 5d50c47..8f40fdc 100644 --- a/serf/test/main/net/event/query_same_clock.rs +++ b/serf/test/main/net/event/query_same_clock.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::{Lpe, Node}, + transport::Node, }; use serf_core::tests::{event::query_same_clock, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, Node::new("fake1".into(), next_socket_addr_v4(0)))); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, Node::new("fake1".into(), next_socket_addr_v4(0)))); diff --git a/serf/test/main/net/event/query_size_limit.rs b/serf/test/main/net/event/query_size_limit.rs index ed8ba33..a688bac 100644 --- a/serf/test/main/net/event/query_size_limit.rs +++ b/serf/test/main/net/event/query_size_limit.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! 
test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{event::serf_query_size_limit, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); diff --git a/serf/test/main/net/event/query_size_limit_increased.rs b/serf/test/main/net/event/query_size_limit_increased.rs index 21755d7..a447c16 100644 --- a/serf/test/main/net/event/query_size_limit_increased.rs +++ b/serf/test/main/net/event/query_size_limit_increased.rs @@ -2,7 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; use crate::[< $rt:snake _run >]; use serf::{ @@ -11,7 +10,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{event::serf_query_size_limit_increased, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +26,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -44,7 +43,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); diff --git a/serf/test/main/net/event/remove_failed_events_leave.rs b/serf/test/main/net/event/remove_failed_events_leave.rs index 8fea390..484f429 100644 --- a/serf/test/main/net/event/remove_failed_events_leave.rs +++ b/serf/test/main/net/event/remove_failed_events_leave.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! 
{ mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{event::serf_remove_failed_events_leave, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -52,7 +50,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); diff --git a/serf/test/main/net/event/should_process.rs b/serf/test/main/net/event/should_process.rs index 573dc9a..e4ca17a 100644 --- a/serf/test/main/net/event/should_process.rs +++ b/serf/test/main/net/event/should_process.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{event::should_process, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -44,7 +42,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); diff --git a/serf/test/main/net/event/user_event_old_message.rs b/serf/test/main/net/event/user_event_old_message.rs index 3280157..4c5e304 100644 --- a/serf/test/main/net/event/user_event_old_message.rs +++ b/serf/test/main/net/event/user_event_old_message.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{event::user_event_old_message, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); diff --git a/serf/test/main/net/event/user_event_same_clock.rs b/serf/test/main/net/event/user_event_same_clock.rs index d06370d..7a514fd 100644 --- a/serf/test/main/net/event/user_event_same_clock.rs +++ b/serf/test/main/net/event/user_event_same_clock.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{event::user_event_same_clock, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); diff --git a/serf/test/main/net/get_queue_max.rs b/serf/test/main/net/get_queue_max.rs index 0911308..57a86a5 100644 --- a/serf/test/main/net/get_queue_max.rs +++ b/serf/test/main/net/get_queue_max.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{serf_get_queue_max, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, |idx| { @@ -46,7 +44,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, |idx| { diff --git a/serf/test/main/net/join/intent_buffer_early.rs b/serf/test/main/net/join/intent_buffer_early.rs index e1fa934..b748370 100644 --- a/serf/test/main/net/join/intent_buffer_early.rs +++ b/serf/test/main/net/join/intent_buffer_early.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{join::join_intent_buffer_early, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); diff --git a/serf/test/main/net/join/intent_newer.rs b/serf/test/main/net/join/intent_newer.rs index 2e0d01c..13720e5 100644 --- a/serf/test/main/net/join/intent_newer.rs +++ b/serf/test/main/net/join/intent_newer.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{join::join_intent_newer, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v4(0))); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v6())); diff --git a/serf/test/main/net/join/intent_old_message.rs b/serf/test/main/net/join/intent_old_message.rs index d28914d..442e155 100644 --- a/serf/test/main/net/join/intent_old_message.rs +++ b/serf/test/main/net/join/intent_old_message.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! 
test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{join::join_intent_old_message, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v4(0))); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v6())); diff --git a/serf/test/main/net/join/intent_reset_leaving.rs b/serf/test/main/net/join/intent_reset_leaving.rs index b5258f0..44346aa 100644 --- a/serf/test/main/net/join/intent_reset_leaving.rs +++ b/serf/test/main/net/join/intent_reset_leaving.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{join::join_intent_reset_leaving, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v4(0))); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v6())); diff --git a/serf/test/main/net/join/join_cancel.rs b/serf/test/main/net/join/join_cancel.rs index 50aa93b..5fb1e6f 100644 --- a/serf/test/main/net/join/join_cancel.rs +++ b/serf/test/main/net/join/join_cancel.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! 
{ mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{join::serf_join_cancel, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -52,7 +50,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); diff --git a/serf/test/main/net/join/join_ignore_old.rs b/serf/test/main/net/join/join_ignore_old.rs index 4442e66..95cf606 100644 --- a/serf/test/main/net/join/join_ignore_old.rs +++ b/serf/test/main/net/join/join_ignore_old.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{join::serf_join_ignore_old, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -52,7 +50,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); diff --git a/serf/test/main/net/join/join_leave.rs b/serf/test/main/net/join/join_leave.rs index a169fbb..135bb5a 100644 --- a/serf/test/main/net/join/join_leave.rs +++ b/serf/test/main/net/join/join_leave.rs @@ -2,8 +2,6 @@ macro_rules! 
test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{join::serf_join_leave, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -52,7 +50,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); diff --git a/serf/test/main/net/join/join_leave_join.rs b/serf/test/main/net/join/join_leave_join.rs index 66a06ed..2b09fdf 100644 --- a/serf/test/main/net/join/join_leave_join.rs +++ b/serf/test/main/net/join/join_leave_join.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{join::serf_join_leave_join, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -52,7 +50,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); diff --git a/serf/test/main/net/join/leave_ltime.rs b/serf/test/main/net/join/leave_ltime.rs index a2d21ca..0bf7e3f 100644 --- a/serf/test/main/net/join/leave_ltime.rs +++ b/serf/test/main/net/join/leave_ltime.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{join::join_leave_ltime, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -52,7 +50,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); diff --git a/serf/test/main/net/join/pending_intent.rs b/serf/test/main/net/join/pending_intent.rs index 62b31ff..ec40495 100644 --- a/serf/test/main/net/join/pending_intent.rs +++ b/serf/test/main/net/join/pending_intent.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{join::join_pending_intent, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v4(0))); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v6())); diff --git a/serf/test/main/net/join/pending_intents.rs b/serf/test/main/net/join/pending_intents.rs index b20d990..7a4d755 100644 --- a/serf/test/main/net/join/pending_intents.rs +++ b/serf/test/main/net/join/pending_intents.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{join::join_pending_intents, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v4(0))); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v6())); diff --git a/serf/test/main/net/leave/force_leave_failed.rs b/serf/test/main/net/leave/force_leave_failed.rs index 7746dc9..7b0ba5c 100644 --- a/serf/test/main/net/leave/force_leave_failed.rs +++ b/serf/test/main/net/leave/force_leave_failed.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! 
test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{leave::serf_force_leave_failed, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -35,7 +33,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2, opts3)); @@ -60,7 +58,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2, opts3)); diff --git a/serf/test/main/net/leave/force_leave_leaving.rs b/serf/test/main/net/leave/force_leave_leaving.rs index 850859e..d202594 100644 --- a/serf/test/main/net/leave/force_leave_leaving.rs +++ b/serf/test/main/net/leave/force_leave_leaving.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{leave::serf_force_leave_leaving, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -35,7 +33,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2, opts3)); @@ -60,7 +58,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2, opts3)); diff --git a/serf/test/main/net/leave/force_leave_left.rs b/serf/test/main/net/leave/force_leave_left.rs index c2a05f2..bdba857 100644 --- a/serf/test/main/net/leave/force_leave_left.rs +++ b/serf/test/main/net/leave/force_leave_left.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! 
{ mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{leave::serf_force_leave_left, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -35,7 +33,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2, opts3)); @@ -60,7 +58,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2, opts3)); diff --git a/serf/test/main/net/leave/intent_buffer_early.rs b/serf/test/main/net/leave/intent_buffer_early.rs index 19c833c..f05c6f0 100644 --- a/serf/test/main/net/leave/intent_buffer_early.rs +++ b/serf/test/main/net/leave/intent_buffer_early.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{leave::leave_intent_buffer_early, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -44,7 +42,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); diff --git a/serf/test/main/net/leave/intent_newer.rs b/serf/test/main/net/leave/intent_newer.rs index b9d008b..65523f5 100644 --- a/serf/test/main/net/leave/intent_newer.rs +++ b/serf/test/main/net/leave/intent_newer.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{leave::leave_intent_newer, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v4(0))); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v6())); diff --git a/serf/test/main/net/leave/intent_old_message.rs b/serf/test/main/net/leave/intent_old_message.rs index 337bfa5..341b2d9 100644 --- a/serf/test/main/net/leave/intent_old_message.rs +++ b/serf/test/main/net/leave/intent_old_message.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{leave::leave_intent_old_message, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v4(0))); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v6())); diff --git a/serf/test/main/net/leave/rejoin_different_role.rs b/serf/test/main/net/leave/rejoin_different_role.rs index dcc25f0..30fbdd1 100644 --- a/serf/test/main/net/leave/rejoin_different_role.rs +++ b/serf/test/main/net/leave/rejoin_different_role.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{leave::serf_leave_rejoin_different_role, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -52,7 +50,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); diff --git a/serf/test/main/net/leave/snapshot_recovery.rs b/serf/test/main/net/leave/snapshot_recovery.rs index c2a1f23..3ab7ce0 100644 --- a/serf/test/main/net/leave/snapshot_recovery.rs +++ b/serf/test/main/net/leave/snapshot_recovery.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! 
test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{leave::serf_leave_snapshot_recovery, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, _ @@ -57,7 +55,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, _ diff --git a/serf/test/main/net/local_member.rs b/serf/test/main/net/local_member.rs index 8e0dd75..4e10a52 100644 --- a/serf/test/main/net/local_member.rs +++ b/serf/test/main/net/local_member.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{serf_local_member, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); diff --git a/serf/test/main/net/name_resolution.rs b/serf/test/main/net/name_resolution.rs index fc60456..7aa0d45 100644 --- a/serf/test/main/net/name_resolution.rs +++ b/serf/test/main/net/name_resolution.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! 
test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{serf_name_resolution, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -35,7 +33,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2, opts3, |opts, id| opts.with_id(id))); @@ -60,7 +58,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2, opts3, |opts, id| opts.with_id(id))); diff --git a/serf/test/main/net/num_nodes.rs b/serf/test/main/net/num_nodes.rs index 577d41e..f9c5b6d 100644 --- a/serf/test/main/net/num_nodes.rs +++ b/serf/test/main/net/num_nodes.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{serf_num_nodes, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -52,7 +50,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); diff --git a/serf/test/main/net/reap/handler.rs b/serf/test/main/net/reap/handler.rs index 2938b04..dc388e3 100644 --- a/serf/test/main/net/reap/handler.rs +++ b/serf/test/main/net/reap/handler.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! 
{ mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{reap::serf_reap_handler, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v4(0))); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v6())); diff --git a/serf/test/main/net/reap/handler_shutdown.rs b/serf/test/main/net/reap/handler_shutdown.rs index 50c9dfd..8b9ad0b 100644 --- a/serf/test/main/net/reap/handler_shutdown.rs +++ b/serf/test/main/net/reap/handler_shutdown.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{reap::serf_reap_handler_shutdown, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); diff --git a/serf/test/main/net/reap/reap.rs b/serf/test/main/net/reap/reap.rs index 7a8cd75..c3374fe 100644 --- a/serf/test/main/net/reap/reap.rs +++ b/serf/test/main/net/reap/reap.rs @@ -2,8 +2,6 @@ macro_rules! 
test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{reap::serf_reap, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v4(0))); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v6())); diff --git a/serf/test/main/net/reconnect/reconnect.rs b/serf/test/main/net/reconnect/reconnect.rs index 159b764..cfb6a8c 100644 --- a/serf/test/main/net/reconnect/reconnect.rs +++ b/serf/test/main/net/reconnect/reconnect.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{reconnect::serf_reconnect, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, _, @@ -57,7 +55,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, _, diff --git a/serf/test/main/net/reconnect/same_ip.rs b/serf/test/main/net/reconnect/same_ip.rs index 8905b25..56dfcee 100644 --- a/serf/test/main/net/reconnect/same_ip.rs +++ b/serf/test/main/net/reconnect/same_ip.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{reconnect::serf_reconnect_same_ip, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -29,7 +27,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, _, @@ -54,7 +52,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, _, diff --git a/serf/test/main/net/reconnect/timeout.rs b/serf/test/main/net/reconnect/timeout.rs index efc2acf..15ce047 100644 --- a/serf/test/main/net/reconnect/timeout.rs +++ b/serf/test/main/net/reconnect/timeout.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{reconnect::serf_per_node_reconnect_timeout, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -52,7 +50,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); diff --git a/serf/test/main/net/remove/failed_node.rs b/serf/test/main/net/remove/failed_node.rs index 8e8cf20..03bdc73 100644 --- a/serf/test/main/net/remove/failed_node.rs +++ b/serf/test/main/net/remove/failed_node.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{remove::serf_remove_failed_node, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -35,7 +33,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2, opts3)); @@ -60,7 +58,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2, opts3)); diff --git a/serf/test/main/net/remove/failed_node_ourself.rs b/serf/test/main/net/remove/failed_node_ourself.rs index e0f7b76..165d246 100644 --- a/serf/test/main/net/remove/failed_node_ourself.rs +++ b/serf/test/main/net/remove/failed_node_ourself.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{remove::serf_remove_failed_node_ourself, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); diff --git a/serf/test/main/net/remove/failed_node_prune.rs b/serf/test/main/net/remove/failed_node_prune.rs index b87939d..9252849 100644 --- a/serf/test/main/net/remove/failed_node_prune.rs +++ b/serf/test/main/net/remove/failed_node_prune.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{remove::serf_remove_failed_node_prune, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -35,7 +33,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2, opts3)); @@ -60,7 +58,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2, opts3)); diff --git a/serf/test/main/net/role.rs b/serf/test/main/net/role.rs index 28ad109..bac8afc 100644 --- a/serf/test/main/net/role.rs +++ b/serf/test/main/net/role.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{serf_role, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -52,7 +50,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); diff --git a/serf/test/main/net/set_tags.rs b/serf/test/main/net/set_tags.rs index a776247..5997f7f 100644 --- a/serf/test/main/net/set_tags.rs +++ b/serf/test/main/net/set_tags.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{serf_set_tags, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); @@ -52,7 +50,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, opts2)); diff --git a/serf/test/main/net/snapshot/snapshoter.rs b/serf/test/main/net/snapshot/snapshoter.rs index 1eef5e2..e13d3c6 100644 --- a/serf/test/main/net/snapshot/snapshoter.rs +++ b/serf/test/main/net/snapshot/snapshoter.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{snapshot::snapshoter, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v4(0))); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v6())); diff --git a/serf/test/main/net/snapshot/snapshoter_force_compact.rs b/serf/test/main/net/snapshot/snapshoter_force_compact.rs index 677a4e8..2218cb5 100644 --- a/serf/test/main/net/snapshot/snapshoter_force_compact.rs +++ b/serf/test/main/net/snapshot/snapshoter_force_compact.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{snapshot::snapshoter_force_compact, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v4(0))); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v6())); diff --git a/serf/test/main/net/snapshot/snapshoter_leave.rs b/serf/test/main/net/snapshot/snapshoter_leave.rs index 288afd6..fb1edcb 100644 --- a/serf/test/main/net/snapshot/snapshoter_leave.rs +++ b/serf/test/main/net/snapshot/snapshoter_leave.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! 
test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{snapshot::snapshoter_leave, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v4(0))); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v6())); diff --git a/serf/test/main/net/snapshot/snapshoter_leave_rejoin.rs b/serf/test/main/net/snapshot/snapshoter_leave_rejoin.rs index 4749dca..221d8bb 100644 --- a/serf/test/main/net/snapshot/snapshoter_leave_rejoin.rs +++ b/serf/test/main/net/snapshot/snapshoter_leave_rejoin.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{snapshot::snapshoter_leave_rejoin, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v4(0))); @@ -44,7 +42,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts, next_socket_addr_v6())); diff --git a/serf/test/main/net/snapshot/snapshoter_recovery.rs b/serf/test/main/net/snapshot/snapshoter_recovery.rs index 2a83808..c7aa5ce 100644 --- a/serf/test/main/net/snapshot/snapshoter_recovery.rs +++ b/serf/test/main/net/snapshot/snapshoter_recovery.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{snapshot::serf_snapshot_recovery, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, _ @@ -57,7 +55,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, _, diff --git a/serf/test/main/net/state.rs b/serf/test/main/net/state.rs index 09342cc..7c9e48a 100644 --- a/serf/test/main/net/state.rs +++ b/serf/test/main/net/state.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{serf_state, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -44,7 +42,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); diff --git a/serf/test/main/net/stats.rs b/serf/test/main/net/stats.rs index 4829426..faca638 100644 --- a/serf/test/main/net/stats.rs +++ b/serf/test/main/net/stats.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{serf_stats, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -27,7 +25,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); @@ -44,7 +42,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(opts)); diff --git a/serf/test/main/net/update.rs b/serf/test/main/net/update.rs index d8f4981..f6213bc 100644 --- a/serf/test/main/net/update.rs +++ b/serf/test/main/net/update.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,7 +9,7 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; use serf_core::tests::{serf_update, next_socket_addr_v4, next_socket_addr_v6}; use smol_str::SmolStr; @@ -31,7 +29,7 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, _, @@ -57,7 +55,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, _, diff --git a/serf/test/main/net/write_keyring_file.rs b/serf/test/main/net/write_keyring_file.rs index 9a61b1e..55d8d24 100644 --- a/serf/test/main/net/write_keyring_file.rs +++ b/serf/test/main/net/write_keyring_file.rs @@ -2,8 +2,6 @@ macro_rules! test_mod { ($rt:ident) => { paste::paste! { mod [< $rt:snake >] { - use std::net::SocketAddr; - use crate::[< $rt:snake _run >]; use serf::{ net::{ @@ -11,9 +9,9 @@ macro_rules! test_mod { NetTransportOptions, }, [< $rt:snake >]::[< $rt:camel Runtime >], - transport::Lpe, + }; - use serf_core::tests::{serf_write_keyring_file, next_socket_addr_v4, next_socket_addr_v6}; + use serf_core::{tests::{serf_write_keyring_file, next_socket_addr_v4, next_socket_addr_v6}, MemberlistOptions, types::EncryptionAlgorithm}; use smol_str::SmolStr; #[test] @@ -27,10 +25,12 @@ macro_rules! test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, - >(|kr| opts.with_primary_key(Some(kr)).with_gossip_verify_outgoing(true).with_encryption_algo(Some(serf::net::security::EncryptionAlgo::default())))); + >(|kr| { + (opts, MemberlistOptions::lan().with_primary_key(Some(kr)).with_gossip_verify_outgoing(true).with_encryption_algo(Some(EncryptionAlgorithm::default()))) + })); } #[test] @@ -44,7 +44,7 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - Lpe, + [< $rt:camel Runtime >], >, >(|kr| opts.with_primary_key(Some(kr)).with_gossip_verify_outgoing(true).with_encryption_algo(Some(serf::net::security::EncryptionAlgo::default())))); From e65a099fbcea2e21036ec046486e66b957e47604 Mon Sep 17 00:00:00 2001 From: al8n Date: Mon, 3 Mar 2025 00:46:09 +0800 Subject: [PATCH 15/39] WIP: fix unit tests --- serf-core/src/lib.rs | 6 +++--- serf-core/src/types/conflict.rs | 3 ++- serf-core/src/types/coordinate.rs | 3 ++- serf-core/src/types/filter.rs | 3 ++- serf-core/src/types/filter/tag_filter.rs | 3 ++- serf-core/src/types/join.rs | 3 ++- serf-core/src/types/key.rs | 6 ++++-- serf-core/src/types/leave.rs | 3 ++- serf-core/src/types/member.rs | 3 ++- serf-core/src/types/message.rs | 6 ++++-- serf-core/src/types/push_pull.rs | 3 ++- serf-core/src/types/query.rs | 3 ++- serf-core/src/types/query/response.rs | 3 ++- serf-core/src/types/tags.rs | 3 ++- serf-core/src/types/user_event.rs | 3 ++- serf-core/src/types/user_event/message.rs | 3 ++- serf-core/src/types/user_event/user_events.rs | 3 ++- serf/test/main/net/write_keyring_file.rs | 4 +++- 18 files changed, 42 insertions(+), 22 deletions(-) diff --git a/serf-core/src/lib.rs b/serf-core/src/lib.rs index ddd2038..79e2720 100644 --- a/serf-core/src/lib.rs +++ b/serf-core/src/lib.rs @@ -97,8 +97,8 @@ pub mod tests { use std::sync::Once; static TRACE: Once = Once::new(); TRACE.call_once(|| { - let filter = std::env::var("RUSERF_TESTING_LOG") - .unwrap_or_else(|_| "serf_core=info,memberlist_core=debug".to_owned()); + let filter = std::env::var("SERF_TESTING_LOG") + .unwrap_or_else(|_| "serf_core=debug,memberlist_core=debug".to_owned()); memberlist_core::tracing::subscriber::set_global_default( tracing_subscriber::fmt::fmt() .without_time() @@ -119,7 +119,7 @@ pub mod tests { B: FnOnce(F) -> F::Output, F: std::future::Future, { - // initialize_tests_tracing(); + 
initialize_tests_tracing(); block_on(fut); } } diff --git a/serf-core/src/types/conflict.rs b/serf-core/src/types/conflict.rs index 5bb686c..70d4036 100644 --- a/serf-core/src/types/conflict.rs +++ b/serf-core/src/types/conflict.rs @@ -127,7 +127,8 @@ where offset += 1; let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + let wire_type = WireType::try_from(wire_type) + .map_err(|v| DecodeError::unknown_wire_type("ConflictResponseMessage", v))?; offset += skip(wire_type, &buf[offset..])?; } } diff --git a/serf-core/src/types/coordinate.rs b/serf-core/src/types/coordinate.rs index 91dbc72..d765ad3 100644 --- a/serf-core/src/types/coordinate.rs +++ b/serf-core/src/types/coordinate.rs @@ -524,7 +524,8 @@ impl<'a> DataRef<'a, Coordinate> for CoordinateRef<'a> { offset += 1; let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + let wire_type = WireType::try_from(wire_type) + .map_err(|v| DecodeError::unknown_wire_type("Coordinate", v))?; offset += skip(wire_type, &buf[offset..])?; } } diff --git a/serf-core/src/types/filter.rs b/serf-core/src/types/filter.rs index 1fb00ce..49e478d 100644 --- a/serf-core/src/types/filter.rs +++ b/serf-core/src/types/filter.rs @@ -155,7 +155,8 @@ where } b => { let (wire_type, _) = split(b); - let wt = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + let wt = WireType::try_from(wire_type) + .map_err(|v| DecodeError::unknown_wire_type("Filter", v))?; offset += 1; offset += skip(wt, &buf[offset..])?; } diff --git a/serf-core/src/types/filter/tag_filter.rs b/serf-core/src/types/filter/tag_filter.rs index 07cef11..eebbb7a 100644 --- a/serf-core/src/types/filter/tag_filter.rs +++ b/serf-core/src/types/filter/tag_filter.rs @@ -58,7 +58,8 @@ impl<'a> DataRef<'a, TagFilter> for TagFilterRef<'a> { offset += 1; let (wire_type, _) = split(other); - let wire_type = 
WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + let wire_type = WireType::try_from(wire_type) + .map_err(|v| DecodeError::unknown_wire_type("TagFilter", v))?; offset += skip(wire_type, &src[offset..])?; } } diff --git a/serf-core/src/types/join.rs b/serf-core/src/types/join.rs index d00fd54..c8a38a1 100644 --- a/serf-core/src/types/join.rs +++ b/serf-core/src/types/join.rs @@ -90,7 +90,8 @@ where offset += 1; let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + let wire_type = WireType::try_from(wire_type) + .map_err(|v| DecodeError::unknown_wire_type("JoinMessage", v))?; offset += skip(wire_type, &buf[offset..])?; } } diff --git a/serf-core/src/types/key.rs b/serf-core/src/types/key.rs index 4247e51..65621bb 100644 --- a/serf-core/src/types/key.rs +++ b/serf-core/src/types/key.rs @@ -48,7 +48,8 @@ impl DataRef<'_, Self> for KeyRequestMessage { offset += 1; let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + let wire_type = WireType::try_from(wire_type) + .map_err(|v| DecodeError::unknown_wire_type("KeyRequestMessage", v))?; offset += skip(wire_type, &buf[offset..])?; } } @@ -263,7 +264,8 @@ impl<'a> DataRef<'a, KeyResponseMessage> for KeyResponseMessageRef<'a> { other => { offset += 1; let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + let wire_type = WireType::try_from(wire_type) + .map_err(|v| DecodeError::unknown_wire_type("KeyResponseMessage", v))?; offset += skip(wire_type, &buf[offset..])?; } } diff --git a/serf-core/src/types/leave.rs b/serf-core/src/types/leave.rs index 2fdeee0..d14cd34 100644 --- a/serf-core/src/types/leave.rs +++ b/serf-core/src/types/leave.rs @@ -107,7 +107,8 @@ where offset += 1; let (wire_type, _) = split(other); - let wire_type = 
WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + let wire_type = WireType::try_from(wire_type) + .map_err(|v| DecodeError::unknown_wire_type("LeaveMessage", v))?; offset += skip(wire_type, &buf[offset..])?; } } diff --git a/serf-core/src/types/member.rs b/serf-core/src/types/member.rs index b090ae8..09edf68 100644 --- a/serf-core/src/types/member.rs +++ b/serf-core/src/types/member.rs @@ -376,7 +376,8 @@ where offset += 1; let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + let wire_type = WireType::try_from(wire_type) + .map_err(|v| DecodeError::unknown_wire_type("Member", v))?; offset += skip(wire_type, &buf[offset..])?; } } diff --git a/serf-core/src/types/message.rs b/serf-core/src/types/message.rs index 915ec2a..a23c58f 100644 --- a/serf-core/src/types/message.rs +++ b/serf-core/src/types/message.rs @@ -672,7 +672,8 @@ where offset += 1; let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + let wire_type = WireType::try_from(wire_type) + .map_err(|v| DecodeError::unknown_wire_type("Message", v))?; offset += skip(wire_type, &buf[offset..])?; } } @@ -742,7 +743,8 @@ where offset += 1; let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + let wire_type = WireType::try_from(wire_type) + .map_err(|v| DecodeError::unknown_wire_type("RelayMessage", v))?; offset += skip(wire_type, &buf[offset..])?; } } diff --git a/serf-core/src/types/push_pull.rs b/serf-core/src/types/push_pull.rs index 20f0291..2b6f743 100644 --- a/serf-core/src/types/push_pull.rs +++ b/serf-core/src/types/push_pull.rs @@ -284,7 +284,8 @@ where offset += 1; let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + let wire_type = WireType::try_from(wire_type) + .map_err(|v| 
DecodeError::unknown_wire_type("PushPullMessage", v))?; offset += skip(wire_type, &buf[offset..])?; } } diff --git a/serf-core/src/types/query.rs b/serf-core/src/types/query.rs index d899594..76134a5 100644 --- a/serf-core/src/types/query.rs +++ b/serf-core/src/types/query.rs @@ -346,7 +346,8 @@ where offset += 1; let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + let wire_type = WireType::try_from(wire_type) + .map_err(|v| DecodeError::unknown_wire_type("QueryMessage", v))?; offset += skip(wire_type, &buf[offset..])?; } } diff --git a/serf-core/src/types/query/response.rs b/serf-core/src/types/query/response.rs index 0c9cf3a..b628d38 100644 --- a/serf-core/src/types/query/response.rs +++ b/serf-core/src/types/query/response.rs @@ -195,7 +195,8 @@ where offset += 1; let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + let wire_type = WireType::try_from(wire_type) + .map_err(|v| DecodeError::unknown_wire_type("QueryResponseMessage", v))?; offset += skip(wire_type, &buf[offset..])?; } } diff --git a/serf-core/src/types/tags.rs b/serf-core/src/types/tags.rs index cae7a42..954c936 100644 --- a/serf-core/src/types/tags.rs +++ b/serf-core/src/types/tags.rs @@ -99,7 +99,8 @@ impl<'a> DataRef<'a, Tags> for TagsRef<'a> { offset += 1; let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + let wire_type = + WireType::try_from(wire_type).map_err(|v| DecodeError::unknown_wire_type("Tags", v))?; offset += skip(wire_type, &src[offset..])?; } } diff --git a/serf-core/src/types/user_event.rs b/serf-core/src/types/user_event.rs index ef63bd8..2edc859 100644 --- a/serf-core/src/types/user_event.rs +++ b/serf-core/src/types/user_event.rs @@ -88,7 +88,8 @@ impl<'a> DataRef<'a, UserEvent> for UserEventRef<'a> { offset += 1; let (wire_type, _) = split(other); - let 
wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + let wire_type = WireType::try_from(wire_type) + .map_err(|v| DecodeError::unknown_wire_type("UserEvent", v))?; offset += skip(wire_type, &buf[offset..])?; } } diff --git a/serf-core/src/types/user_event/message.rs b/serf-core/src/types/user_event/message.rs index cc6bebf..5b50b02 100644 --- a/serf-core/src/types/user_event/message.rs +++ b/serf-core/src/types/user_event/message.rs @@ -175,7 +175,8 @@ impl<'a> DataRef<'a, UserEventMessage> for UserEventMessageRef<'a> { offset += 1; let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + let wire_type = WireType::try_from(wire_type) + .map_err(|v| DecodeError::unknown_wire_type("UserEventMessage", v))?; offset += skip(wire_type, &buf[offset..])?; } } diff --git a/serf-core/src/types/user_event/user_events.rs b/serf-core/src/types/user_event/user_events.rs index 9edd6f5..34f0577 100644 --- a/serf-core/src/types/user_event/user_events.rs +++ b/serf-core/src/types/user_event/user_events.rs @@ -98,7 +98,8 @@ impl<'a> DataRef<'a, UserEvents> for UserEventsRef<'a> { offset += 1; let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type).map_err(DecodeError::unknown_wire_type)?; + let wire_type = WireType::try_from(wire_type) + .map_err(|v| DecodeError::unknown_wire_type("UserEvents", v))?; offset += skip(wire_type, &buf[offset..])?; } } diff --git a/serf/test/main/net/write_keyring_file.rs b/serf/test/main/net/write_keyring_file.rs index 55d8d24..dea2559 100644 --- a/serf/test/main/net/write_keyring_file.rs +++ b/serf/test/main/net/write_keyring_file.rs @@ -47,7 +47,9 @@ macro_rules! 
test_mod { [< $rt:camel Runtime >], >, - >(|kr| opts.with_primary_key(Some(kr)).with_gossip_verify_outgoing(true).with_encryption_algo(Some(serf::net::security::EncryptionAlgo::default())))); + >(|kr| { + (opts, MemberlistOptions::lan().with_primary_key(Some(kr)).with_gossip_verify_outgoing(true).with_encryption_algo(Some(EncryptionAlgorithm::default()))) + })); } } } From 2cbf13d24cb6998dbe96a10cf1bd6286c9e6d769 Mon Sep 17 00:00:00 2001 From: al8n Date: Mon, 3 Mar 2025 01:40:39 +0800 Subject: [PATCH 16/39] WIP --- serf-core/src/serf/delegate.rs | 7 ++++++- serf-core/src/types/push_pull.rs | 7 +------ serf-core/src/types/tests.rs | 8 ++++++++ 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/serf-core/src/serf/delegate.rs b/serf-core/src/serf/delegate.rs index 6e8cae4..622b7ad 100644 --- a/serf-core/src/serf/delegate.rs +++ b/serf-core/src/serf/delegate.rs @@ -408,7 +408,10 @@ where drop(members); match crate::types::encode_message_to_bytes(&pp) { - Ok(buf) => buf, + Ok(buf) => { + tracing::debug!(data=?buf.as_ref(), "serf: local state"); + buf + }, Err(e) => { tracing::error!(err=%e, "serf: failed to encode local state"); Bytes::new() @@ -422,6 +425,8 @@ where return; } + tracing::debug!(data=?buf, "serf: merge remote state"); + // Check the message type let msg = match crate::types::decode_message::(buf) { Ok(msg) => msg, diff --git a/serf-core/src/types/push_pull.rs b/serf-core/src/types/push_pull.rs index 2b6f743..d570d3c 100644 --- a/serf-core/src/types/push_pull.rs +++ b/serf-core/src/types/push_pull.rs @@ -508,8 +508,7 @@ where .map(|id| 1 + id.encoded_len_with_length_delimited()) .sum::(); len += 1 + self.event_ltime.encoded_len(); - len += 1 - + self + len += self .events .iter() .filter_map(|e| { @@ -542,10 +541,6 @@ where offset += 1; offset += self.ltime.encode(&mut buf[offset..])?; - bail!(self(offset, buf_len)); - buf[offset] = STATUS_LTIMES_BYTE; - offset += 1; - self .status_ltimes .iter() diff --git a/serf-core/src/types/tests.rs 
b/serf-core/src/types/tests.rs index a09b147..8401d25 100644 --- a/serf-core/src/types/tests.rs +++ b/serf-core/src/types/tests.rs @@ -465,3 +465,11 @@ encodable_round_trip!( , , ); + +#[test] +fn test() { + let data = [19, 33, 9, 1, 18, 18, 24, 17, 19, 115, 101, 114, 102, 95, 106, 111, 105, 110, 95, 108, 101, 97, 118, 101, 49, 95, 118, 52, 18, 1, 0, 12, 1, 14, 1]; + + let msg = super::decode_message::(&data).unwrap(); + // println!("{:?}", msg); +} From b3e83cc81d044e38676ed650351289ab5d4fd4a4 Mon Sep 17 00:00:00 2001 From: al8n Date: Mon, 3 Mar 2025 15:48:24 +0800 Subject: [PATCH 17/39] WIP --- serf-core/src/lib.rs | 4 +- serf-core/src/options.rs | 2 + serf-core/src/serf/delegate.rs | 12 ++-- serf-core/src/types/arbitrary_impl.rs | 23 ++++++++ serf-core/src/types/conflict.rs | 11 +--- serf-core/src/types/coordinate.rs | 32 +++++------ serf-core/src/types/filter.rs | 17 ++---- serf-core/src/types/filter/tag_filter.rs | 11 +--- serf-core/src/types/join.rs | 11 +--- serf-core/src/types/key.rs | 26 ++------- serf-core/src/types/leave.rs | 11 +--- serf-core/src/types/member.rs | 11 +--- serf-core/src/types/message.rs | 55 ++++++++++--------- serf-core/src/types/push_pull.rs | 46 ++++++---------- serf-core/src/types/query.rs | 18 ++---- serf-core/src/types/query/response.rs | 11 +--- serf-core/src/types/quickcheck_impl.rs | 20 +++++-- serf-core/src/types/tags.rs | 19 ++----- serf-core/src/types/tests.rs | 27 +++++++-- serf-core/src/types/user_event.rs | 11 +--- serf-core/src/types/user_event/message.rs | 11 +--- serf-core/src/types/user_event/user_events.rs | 19 ++----- serf/test/main/net/coordinates.rs | 40 +++++++------- 23 files changed, 194 insertions(+), 254 deletions(-) diff --git a/serf-core/src/lib.rs b/serf-core/src/lib.rs index 79e2720..092d41c 100644 --- a/serf-core/src/lib.rs +++ b/serf-core/src/lib.rs @@ -98,7 +98,7 @@ pub mod tests { static TRACE: Once = Once::new(); TRACE.call_once(|| { let filter = std::env::var("SERF_TESTING_LOG") - .unwrap_or_else(|_| 
"serf_core=debug,memberlist_core=debug".to_owned()); + .unwrap_or_else(|_| "serf_core=debug,memberlist_core=info".to_owned()); memberlist_core::tracing::subscriber::set_global_default( tracing_subscriber::fmt::fmt() .without_time() @@ -119,7 +119,7 @@ pub mod tests { B: FnOnce(F) -> F::Output, F: std::future::Future, { - initialize_tests_tracing(); + // initialize_tests_tracing(); block_on(fut); } } diff --git a/serf-core/src/options.rs b/serf-core/src/options.rs index b59c8bc..167bf18 100644 --- a/serf-core/src/options.rs +++ b/serf-core/src/options.rs @@ -480,6 +480,7 @@ impl Clone for Options { fn clone(&self) -> Self { Self { memberlist_options: self.memberlist_options.clone(), + #[cfg(feature = "encryption")] keyring_file: self.keyring_file.clone(), snapshot_path: self.snapshot_path.clone(), tags: self.tags.clone(), @@ -522,6 +523,7 @@ impl Options { rejoin_after_leave: false, enable_id_conflict_resolution: true, disable_coordinates: false, + #[cfg(feature = "encryption")] keyring_file: None, max_user_event_size: 512, } diff --git a/serf-core/src/serf/delegate.rs b/serf-core/src/serf/delegate.rs index 622b7ad..15b8730 100644 --- a/serf-core/src/serf/delegate.rs +++ b/serf-core/src/serf/delegate.rs @@ -6,7 +6,7 @@ use crate::{ event::QueryMessageExt, types::{ DelegateVersion, JoinMessage, LamportTime, LeaveMessage, Member, MemberStatus, MessageRef, - ProtocolVersion, PushPullMessageBorrow, UserEventMessage, + ProtocolVersion, PushPullMessageBorrow, RelayMessageRef, UserEventMessage, }, }; @@ -260,11 +260,11 @@ where tracing::warn!(err=%e, "serf: failed to decode query response message"); } } - MessageRef::Relay { + MessageRef::Relay(RelayMessageRef { node, payload, payload_offset, - } => { + }) => { tracing::debug!("serf: relaying response to node: {:?}", node); match Data::from_ref(*node.address()) { Err(e) => { @@ -411,7 +411,7 @@ where Ok(buf) => { tracing::debug!(data=?buf.as_ref(), "serf: local state"); buf - }, + } Err(e) => { tracing::error!(err=%e, 
"serf: failed to encode local state"); Bytes::new() @@ -692,6 +692,7 @@ where if let Err(e) = coord.encode(&mut buf[1..]) { tracing::error!(err=%e, "serf: failed to encode coordinate"); } + tracing::trace!(coordinate=?coord, data=?buf.as_ref(), "serf: ack payload"); buf.into() } else { Bytes::new() @@ -709,11 +710,12 @@ where } let this = self.this(); + tracing::trace!(data=?payload.as_ref(), "serf: receive payload"); if let Some(ref c) = this.inner.coord_core { // Verify ping version in the header. if payload[0] != PING_VERSION { - tracing::error!("serf: unsupported ping version: {}", payload[0]); + tracing::error!(version = %payload[0], "serf: unsupported ping version"); return; } diff --git a/serf-core/src/types/arbitrary_impl.rs b/serf-core/src/types/arbitrary_impl.rs index 197556c..bc33614 100644 --- a/serf-core/src/types/arbitrary_impl.rs +++ b/serf-core/src/types/arbitrary_impl.rs @@ -165,3 +165,26 @@ impl<'a> Arbitrary<'a> for MessageType { .map(|val| Self::from(val % Self::ALL.len() as u8)) } } + +impl<'a> Arbitrary<'a> for super::coordinate::Coordinate { + fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result { + Ok(Self { + portion: Vec::::arbitrary(u)? 
+ .into_iter() + .map(|f| if f.is_nan() { 0.0 } else { f }) + .collect(), + error: rand_f64_not_nan(u)?, + adjustment: rand_f64_not_nan(u)?, + height: rand_f64_not_nan(u)?, + }) + } +} + +fn rand_f64_not_nan(u: &mut Unstructured<'_>) -> arbitrary::Result { + loop { + let f = f64::arbitrary(u)?; + if !f.is_nan() { + return Ok(f); + } + } +} diff --git a/serf-core/src/types/conflict.rs b/serf-core/src/types/conflict.rs index 70d4036..1482c93 100644 --- a/serf-core/src/types/conflict.rs +++ b/serf-core/src/types/conflict.rs @@ -1,6 +1,6 @@ use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, WireType, - utils::{merge, skip, split}, + utils::{merge, skip}, }; use super::*; @@ -123,14 +123,7 @@ where offset += len; member = Some(val); } - other => { - offset += 1; - - let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type) - .map_err(|v| DecodeError::unknown_wire_type("ConflictResponseMessage", v))?; - offset += skip(wire_type, &buf[offset..])?; - } + _ => offset += skip("ConflictResponseMessage", &buf[offset..])?, } } diff --git a/serf-core/src/types/coordinate.rs b/serf-core/src/types/coordinate.rs index d765ad3..5a74993 100644 --- a/serf-core/src/types/coordinate.rs +++ b/serf-core/src/types/coordinate.rs @@ -2,7 +2,7 @@ use core::time::Duration; use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, RepeatedDecoder, WireType, - utils::{merge, skip, split}, + utils::{merge, skip}, }; use rand::Rng; use smallvec::SmallVec; @@ -209,7 +209,6 @@ impl CoordinateOptions { #[viewit::viewit(getters(style = "move"), setters(prefix = "with"))] #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub struct Coordinate { /// The Euclidean portion of the coordinate. This is used along /// with the other fields to provide an overall distance estimate. 
The @@ -222,7 +221,6 @@ pub struct Coordinate { ), setter(attrs(doc = "Sets the Euclidean portion of the coordinate.")) )] - #[cfg_attr(feature = "arbitrary", arbitrary(with = crate::types::arbitrary_impl::into::, SmallVec<[f64; DEFAULT_DIMENSIONALITY]>>))] portion: SmallVec<[f64; DEFAULT_DIMENSIONALITY]>, /// Reflects the confidence in the given coordinate and is updated /// dynamically by the Vivaldi Client. This is dimensionless. @@ -435,7 +433,7 @@ const PORTION_TAG: u8 = 1; const ERROR_TAG: u8 = 2; const ADJUSTMENT_TAG: u8 = 3; const HEIGHT_TAG: u8 = 4; -const PORTION_BYTE: u8 = merge(WireType::LengthDelimited, PORTION_TAG); +const PORTION_BYTE: u8 = merge(WireType::Fixed64, PORTION_TAG); const ERROR_BYTE: u8 = merge(WireType::Fixed64, ERROR_TAG); const ADJUSTMENT_BYTE: u8 = merge(WireType::Fixed64, ADJUSTMENT_TAG); const HEIGHT_BYTE: u8 = merge(WireType::Fixed64, HEIGHT_TAG); @@ -465,23 +463,23 @@ impl<'a> DataRef<'a, Coordinate> for CoordinateRef<'a> { while offset < buf_len { match buf[offset] { - PORTION_TAG => { - let readed = skip(WireType::Fixed64, &buf[offset..])?; + PORTION_BYTE => { + let readed = skip("Coordinate", &buf[offset..])?; if let Some((ref mut fnso, ref mut lnso)) = portion_offsets { if *fnso > offset { - *fnso = offset - 1; + *fnso = offset; } if *lnso < offset + readed { *lnso = offset + readed; } } else { - portion_offsets = Some((offset - 1, offset + readed)); + portion_offsets = Some((offset, offset + readed)); } num_portions += 1; offset += readed; } - ERROR_TAG => { + ERROR_BYTE => { if error.is_some() { return Err(DecodeError::duplicate_field( "Coordinate", @@ -489,12 +487,13 @@ impl<'a> DataRef<'a, Coordinate> for CoordinateRef<'a> { ERROR_TAG, )); } + offset += 1; let (len, val) = ::decode(&buf[offset..])?; offset += len; error = Some(val); } - ADJUSTMENT_TAG => { + ADJUSTMENT_BYTE => { if adjustment.is_some() { return Err(DecodeError::duplicate_field( "Coordinate", @@ -502,12 +501,13 @@ impl<'a> DataRef<'a, Coordinate> for 
CoordinateRef<'a> { ADJUSTMENT_TAG, )); } + offset += 1; let (len, val) = ::decode(&buf[offset..])?; offset += len; adjustment = Some(val); } - HEIGHT_TAG => { + HEIGHT_BYTE => { if height.is_some() { return Err(DecodeError::duplicate_field( "Coordinate", @@ -515,19 +515,13 @@ impl<'a> DataRef<'a, Coordinate> for CoordinateRef<'a> { HEIGHT_TAG, )); } + offset += 1; let (len, val) = ::decode(&buf[offset..])?; offset += len; height = Some(val); } - other => { - offset += 1; - - let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type) - .map_err(|v| DecodeError::unknown_wire_type("Coordinate", v))?; - offset += skip(wire_type, &buf[offset..])?; - } + _ => offset += skip("Coordinate", &buf[offset..])?, } } diff --git a/serf-core/src/types/filter.rs b/serf-core/src/types/filter.rs index 49e478d..bf258d2 100644 --- a/serf-core/src/types/filter.rs +++ b/serf-core/src/types/filter.rs @@ -1,6 +1,6 @@ use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, RepeatedDecoder, TinyVec, WireType, - utils::{merge, skip, split}, + utils::{merge, skip}, }; pub use tag_filter::*; @@ -118,18 +118,17 @@ where while offset < buf_len { match buf[offset] { val if val == Filter::::id_byte() => { - offset += 1; - let readed = skip(I::WIRE_TYPE, &buf[offset..])?; + let readed = skip("Filter", &buf[offset..])?; if let Some((ref mut fnso, ref mut lnso)) = ids_offsets { if *fnso > offset { - *fnso = offset - 1; + *fnso = offset; } if *lnso < offset + readed { *lnso = offset + readed; } } else { - ids_offsets = Some((offset - 1, offset + readed)); + ids_offsets = Some((offset, offset + readed)); } num_ids += 1; offset += readed; @@ -153,13 +152,7 @@ where offset += read; f = Some(FilterRef::Tag(tag)); } - b => { - let (wire_type, _) = split(b); - let wt = WireType::try_from(wire_type) - .map_err(|v| DecodeError::unknown_wire_type("Filter", v))?; - offset += 1; - offset += skip(wt, &buf[offset..])?; - } + _ => offset += skip("Filter", &buf[offset..])?, } } 
diff --git a/serf-core/src/types/filter/tag_filter.rs b/serf-core/src/types/filter/tag_filter.rs index eebbb7a..7d69dcd 100644 --- a/serf-core/src/types/filter/tag_filter.rs +++ b/serf-core/src/types/filter/tag_filter.rs @@ -1,6 +1,6 @@ use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, WireType, - utils::{merge, skip, split}, + utils::{merge, skip}, }; use regex::Regex; use smol_str::SmolStr; @@ -54,14 +54,7 @@ impl<'a> DataRef<'a, TagFilter> for TagFilterRef<'a> { offset += read; expr = Some(value); } - other => { - offset += 1; - - let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type) - .map_err(|v| DecodeError::unknown_wire_type("TagFilter", v))?; - offset += skip(wire_type, &src[offset..])?; - } + _ => offset += skip("TagFilter", &src[offset..])?, } } diff --git a/serf-core/src/types/join.rs b/serf-core/src/types/join.rs index c8a38a1..c596ac6 100644 --- a/serf-core/src/types/join.rs +++ b/serf-core/src/types/join.rs @@ -1,6 +1,6 @@ use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, WireType, - utils::{merge, skip, split}, + utils::{merge, skip}, }; use super::LamportTime; @@ -86,14 +86,7 @@ where offset += read; id = Some(value); } - other => { - offset += 1; - - let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type) - .map_err(|v| DecodeError::unknown_wire_type("JoinMessage", v))?; - offset += skip(wire_type, &buf[offset..])?; - } + _ => offset += skip("JoinMessage", &buf[offset..])?, } } diff --git a/serf-core/src/types/key.rs b/serf-core/src/types/key.rs index 65621bb..040d70f 100644 --- a/serf-core/src/types/key.rs +++ b/serf-core/src/types/key.rs @@ -1,7 +1,7 @@ use indexmap::IndexMap; use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, RepeatedDecoder, SecretKey, SecretKeys, WireType, - utils::{merge, skip, split}, + utils::{merge, skip}, }; use smol_str::SmolStr; @@ -44,14 +44,7 @@ impl DataRef<'_, Self> for KeyRequestMessage { offset += 
bytes_read; key = Some(val); } - other => { - offset += 1; - - let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type) - .map_err(|v| DecodeError::unknown_wire_type("KeyRequestMessage", v))?; - offset += skip(wire_type, &buf[offset..])?; - } + _ => offset += skip("KeyRequestMessage", &buf[offset..])?, } } @@ -231,18 +224,17 @@ impl<'a> DataRef<'a, KeyResponseMessage> for KeyResponseMessageRef<'a> { message = Some(val); } KEY_RESPONSE_KEYS_BYTE => { - offset += 1; - let readed = skip(WireType::LengthDelimited, &buf[offset..])?; + let readed = skip("KeyResponseMessage", &buf[offset..])?; if let Some((ref mut fnso, ref mut lnso)) = keys_offsets { if *fnso > offset { - *fnso = offset - 1; + *fnso = offset; } if *lnso < offset + readed { *lnso = offset + readed; } } else { - keys_offsets = Some((offset - 1, offset + readed)); + keys_offsets = Some((offset, offset + readed)); } num_keys += 1; offset += readed; @@ -261,13 +253,7 @@ impl<'a> DataRef<'a, KeyResponseMessage> for KeyResponseMessageRef<'a> { offset += bytes_read; primary_key = Some(val); } - other => { - offset += 1; - let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type) - .map_err(|v| DecodeError::unknown_wire_type("KeyResponseMessage", v))?; - offset += skip(wire_type, &buf[offset..])?; - } + _ => offset += skip("KeyResponseMessage", &buf[offset..])?, } } diff --git a/serf-core/src/types/leave.rs b/serf-core/src/types/leave.rs index d14cd34..ff721a4 100644 --- a/serf-core/src/types/leave.rs +++ b/serf-core/src/types/leave.rs @@ -1,6 +1,6 @@ use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, WireType, - utils::{merge, skip, split}, + utils::{merge, skip}, }; use super::LamportTime; @@ -103,14 +103,7 @@ where offset += read; id = Some(id_ref); } - other => { - offset += 1; - - let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type) - .map_err(|v| DecodeError::unknown_wire_type("LeaveMessage", v))?; - offset 
+= skip(wire_type, &buf[offset..])?; - } + _ => offset += skip("LeaveMessage", &buf[offset..])?, } } diff --git a/serf-core/src/types/member.rs b/serf-core/src/types/member.rs index 09edf68..44ac15f 100644 --- a/serf-core/src/types/member.rs +++ b/serf-core/src/types/member.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use memberlist_core::proto::{ CheapClone, Data, DataRef, DecodeError, EncodeError, OneOrMore, WireType, - utils::{merge, skip, split}, + utils::{merge, skip}, }; use super::{ @@ -372,14 +372,7 @@ where delegate_version = Some(buf[offset].into()); offset += 1; } - other => { - offset += 1; - - let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type) - .map_err(|v| DecodeError::unknown_wire_type("Member", v))?; - offset += skip(wire_type, &buf[offset..])?; - } + _ => offset += skip("Member", &buf[offset..])?, } } diff --git a/serf-core/src/types/message.rs b/serf-core/src/types/message.rs index a23c58f..1d5813f 100644 --- a/serf-core/src/types/message.rs +++ b/serf-core/src/types/message.rs @@ -1,7 +1,7 @@ use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, Node, WireType, bytes::Bytes, - utils::{merge, skip, split}, + utils::{merge, skip}, }; use super::{ @@ -299,7 +299,29 @@ where } } +/// A reference type to a relay message. +#[viewit::viewit(vis_all = "pub(crate)", getters(vis_all = "pub"), setters(skip))] +#[derive(Debug, Clone, Copy)] +pub struct RelayMessageRef<'a, I, A> { + /// The node + #[viewit(getter(style = "ref", attrs(doc = "Get the node to relay to")))] + node: Node, + /// The offset of the payload to the original buffer + #[viewit(getter( + style = "move", + attrs(doc = "Get the offset of the payload to the original buffer") + ))] + payload_offset: usize, + /// The relay message payload + #[viewit(getter(style = "move", attrs(doc = "Get the relay message payload")))] + payload: &'a [u8], +} + /// A reference to a message. 
+#[derive(Debug, derive_more::IsVariant, derive_more::Unwrap, derive_more::TryUnwrap)] +#[unwrap(ref)] +#[try_unwrap(ref)] +#[non_exhaustive] pub enum MessageRef<'a, I, A> { /// Leave message Leave(LeaveMessage), @@ -316,14 +338,7 @@ pub enum MessageRef<'a, I, A> { /// ConflictResponse message ConflictResponse(ConflictResponseMessageRef<'a, I, A>), /// Relay message - Relay { - /// The node - node: Node, - /// The offset of the payload to the original buffer - payload_offset: usize, - /// The relay message payload - payload: &'a [u8], - }, + Relay(RelayMessageRef<'a, I, A>), #[cfg(feature = "encryption")] /// KeyRequest message KeyRequest(KeyRequestMessage), @@ -626,11 +641,11 @@ where offset += 1; let (readed, (node, payload)) = decode_relay::(&buf[offset..])?; offset += readed; - msg = Some(MessageRef::Relay { + msg = Some(MessageRef::Relay(RelayMessageRef { node, payload, payload_offset: offset - payload.len(), - }); + })); } #[cfg(feature = "encryption")] KEY_REQUEST_MESSAGE_BYTE => { @@ -668,14 +683,7 @@ where offset += len; msg = Some(MessageRef::KeyResponse(val)); } - other => { - offset += 1; - - let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type) - .map_err(|v| DecodeError::unknown_wire_type("Message", v))?; - offset += skip(wire_type, &buf[offset..])?; - } + _ => offset += skip("Message", &buf[offset..])?, } } @@ -739,14 +747,7 @@ where offset += length as usize; msg = Some(&buf[start_offset..offset]); } - other => { - offset += 1; - - let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type) - .map_err(|v| DecodeError::unknown_wire_type("RelayMessage", v))?; - offset += skip(wire_type, &buf[offset..])?; - } + _ => offset += skip("RelayMessage", &buf[offset..])?, } } diff --git a/serf-core/src/types/push_pull.rs b/serf-core/src/types/push_pull.rs index d570d3c..ef649dc 100644 --- a/serf-core/src/types/push_pull.rs +++ b/serf-core/src/types/push_pull.rs @@ -1,7 +1,7 @@ use indexmap::{IndexMap, 
IndexSet}; use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, RepeatedDecoder, TinyVec, TupleEncoder, WireType, - utils::{merge, skip, split}, + utils::{merge, skip}, }; use super::{LamportTime, UserEvents}; @@ -202,35 +202,33 @@ where ltime = Some(v); } STATUS_LTIMES_BYTE => { - offset += 1; - let readed = skip(WireType::LengthDelimited, &buf[offset..])?; + let readed = skip("PushPull", &buf[offset..])?; if let Some((ref mut fnso, ref mut lnso)) = status_ltimes_offsets { if *fnso > offset { - *fnso = offset - 1; + *fnso = offset; } if *lnso < offset + readed { *lnso = offset + readed; } } else { - status_ltimes_offsets = Some((offset - 1, offset + readed)); + status_ltimes_offsets = Some((offset, offset + readed)); } num_status_ltimes += 1; offset += readed; } b if b == left_members_byte => { - offset += 1; - let readed = skip(I::WIRE_TYPE, &buf[offset..])?; + let readed = skip("PushPull", &buf[offset..])?; if let Some((ref mut fnso, ref mut lnso)) = left_members_offsets { if *fnso > offset { - *fnso = offset - 1; + *fnso = offset; } if *lnso < offset + readed { *lnso = offset + readed; } } else { - left_members_offsets = Some((offset - 1, offset + readed)); + left_members_offsets = Some((offset, offset + readed)); } num_left_members += 1; offset += readed; @@ -250,18 +248,17 @@ where event_ltime = Some(v); } EVENTS_BYTE => { - offset += 1; - let readed = skip(WireType::LengthDelimited, &buf[offset..])?; + let readed = skip("PushPull", &buf[offset..])?; if let Some((ref mut fnso, ref mut lnso)) = events_offsets { if *fnso > offset { - *fnso = offset - 1; + *fnso = offset; } if *lnso < offset + readed { *lnso = offset + readed; } } else { - events_offsets = Some((offset - 1, offset + readed)); + events_offsets = Some((offset, offset + readed)); } num_events += 1; offset += readed; @@ -280,14 +277,7 @@ where offset += o; query_ltime = Some(v); } - other => { - offset += 1; - - let (wire_type, _) = split(other); - let wire_type = 
WireType::try_from(wire_type) - .map_err(|v| DecodeError::unknown_wire_type("PushPullMessage", v))?; - offset += skip(wire_type, &buf[offset..])?; - } + _ => offset += skip("PushPull", &buf[offset..])?, } } @@ -509,13 +499,13 @@ where .sum::(); len += 1 + self.event_ltime.encoded_len(); len += self - .events - .iter() - .filter_map(|e| { - e.as_ref() - .map(|e| 1 + e.encoded_len_with_length_delimited()) - }) - .sum::(); + .events + .iter() + .filter_map(|e| { + e.as_ref() + .map(|e| 1 + e.encoded_len_with_length_delimited()) + }) + .sum::(); len += 1 + self.query_ltime.encoded_len(); len diff --git a/serf-core/src/types/query.rs b/serf-core/src/types/query.rs index 76134a5..3209612 100644 --- a/serf-core/src/types/query.rs +++ b/serf-core/src/types/query.rs @@ -5,7 +5,7 @@ use std::time::Duration; use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, Node, RepeatedDecoder, TinyVec, WireType, bytes::Bytes, - utils::{merge, skip, split}, + utils::{merge, skip}, }; use super::{Filter, LamportTime}; @@ -256,18 +256,17 @@ where from = Some(v); } FILTERS_BYTE => { - offset += 1; - let readed = skip(WireType::LengthDelimited, &buf[offset..])?; + let readed = skip("QueryMessage", &buf[offset..])?; if let Some((ref mut fnso, ref mut lnso)) = filters_offsets { if *fnso > offset { - *fnso = offset - 1; + *fnso = offset; } if *lnso < offset + readed { *lnso = offset + readed; } } else { - filters_offsets = Some((offset - 1, offset + readed)); + filters_offsets = Some((offset, offset + readed)); } num_filters += 1; offset += readed; @@ -342,14 +341,7 @@ where offset += o; payload = Some(v); } - other => { - offset += 1; - - let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type) - .map_err(|v| DecodeError::unknown_wire_type("QueryMessage", v))?; - offset += skip(wire_type, &buf[offset..])?; - } + _ => offset += skip("QueryMessage", &buf[offset..])?, } } diff --git a/serf-core/src/types/query/response.rs 
b/serf-core/src/types/query/response.rs index b628d38..e79364b 100644 --- a/serf-core/src/types/query/response.rs +++ b/serf-core/src/types/query/response.rs @@ -1,7 +1,7 @@ use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, Node, WireType, bytes::Bytes, - utils::{merge, skip, split}, + utils::{merge, skip}, }; use super::{LamportTime, QueryFlag}; @@ -191,14 +191,7 @@ where offset += o; payload = Some(v); } - other => { - offset += 1; - - let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type) - .map_err(|v| DecodeError::unknown_wire_type("QueryResponseMessage", v))?; - offset += skip(wire_type, &buf[offset..])?; - } + _ => offset += skip("QueryResponseMessage", &buf[offset..])?, } } diff --git a/serf-core/src/types/quickcheck_impl.rs b/serf-core/src/types/quickcheck_impl.rs index 953131a..61faba0 100644 --- a/serf-core/src/types/quickcheck_impl.rs +++ b/serf-core/src/types/quickcheck_impl.rs @@ -303,10 +303,22 @@ impl Arbitrary for MessageType { impl Arbitrary for Coordinate { fn arbitrary(g: &mut Gen) -> Self { Self { - portion: Vec::arbitrary(g).into(), - error: Arbitrary::arbitrary(g), - adjustment: Arbitrary::arbitrary(g), - height: Arbitrary::arbitrary(g), + portion: Vec::::arbitrary(g) + .into_iter() + .map(|f| if f.is_nan() { 0.0 } else { f }) + .collect(), + error: rand_f64_not_nan(g), + adjustment: rand_f64_not_nan(g), + height: rand_f64_not_nan(g), + } + } +} + +fn rand_f64_not_nan(u: &mut Gen) -> f64 { + loop { + let f = f64::arbitrary(u); + if !f.is_nan() { + return f; } } } diff --git a/serf-core/src/types/tags.rs b/serf-core/src/types/tags.rs index 954c936..31f5c67 100644 --- a/serf-core/src/types/tags.rs +++ b/serf-core/src/types/tags.rs @@ -1,7 +1,7 @@ use indexmap::IndexMap; use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, RepeatedDecoder, TupleEncoder, WireType, - utils::{merge, skip, split}, + utils::{merge, skip}, }; use smol_str::SmolStr; @@ -78,31 +78,22 @@ impl<'a> 
DataRef<'a, Tags> for TagsRef<'a> { while offset < buf_len { match src[offset] { TAGS_BYTE => { - offset += 1; - - let readed = skip(WireType::LengthDelimited, &src[offset..])?; + let readed = skip("Tags", &src[offset..])?; if let Some((ref mut fnso, ref mut lnso)) = tags_offsets { if *fnso > offset { - *fnso = offset - 1; + *fnso = offset; } if *lnso < offset + readed { *lnso = offset + readed; } } else { - tags_offsets = Some((offset - 1, offset + readed)); + tags_offsets = Some((offset, offset + readed)); } num_tags += 1; offset += readed; } - other => { - offset += 1; - - let (wire_type, _) = split(other); - let wire_type = - WireType::try_from(wire_type).map_err(|v| DecodeError::unknown_wire_type("Tags", v))?; - offset += skip(wire_type, &src[offset..])?; - } + _ => offset += skip("Tags", &src[offset..])?, } } diff --git a/serf-core/src/types/tests.rs b/serf-core/src/types/tests.rs index 8401d25..6f70106 100644 --- a/serf-core/src/types/tests.rs +++ b/serf-core/src/types/tests.rs @@ -6,7 +6,7 @@ use memberlist_core::{ }; use quickcheck::{Arbitrary, Gen}; -use super::*; +use super::{coordinate::Coordinate, *}; fn data_round_trip(data: &T) { let mut buf = vec![0; data.encoded_len() + 2]; @@ -74,6 +74,11 @@ type QueryResponseMessageU64U64 = QueryResponseMessage; type QueryResponseMessageStringU64 = QueryResponseMessage; data_round_trip! { + Coordinate, +} + +data_round_trip! { + // Coordinate, ConflictResponseMessageStringString, ConflictResponseMessageU64U64, ConflictResponseMessageStringU64, @@ -193,7 +198,7 @@ where let data = encode_relay(&$input, &$node); assert_eq!(data.len(), encoded_relay_message_len(&$input, &$node), "relay message: length mismatch"); let decoded = super::decode_message :: < $($g),* > (&data).unwrap(); - let MessageRef::Relay { node, payload, .. } = decoded else { return false }; + let MessageRef::Relay(RelayMessageRef { node, payload, .. 
}) = decoded else { return false }; assert_eq!( as Data>::from_ref(node).unwrap(), $node, "relay message: node mismatch"); let decoded = super::decode_message :: < $($g),* > (&payload).unwrap(); @@ -468,8 +473,22 @@ encodable_round_trip!( #[test] fn test() { - let data = [19, 33, 9, 1, 18, 18, 24, 17, 19, 115, 101, 114, 102, 95, 106, 111, 105, 110, 95, 108, 101, 97, 118, 101, 49, 95, 118, 52, 18, 1, 0, 12, 1, 14, 1]; + let data = [ + 19, 33, 9, 1, 18, 25, 17, 20, 115, 101, 114, 102, 95, 99, 111, 111, 114, 100, 105, 110, 97, + 116, 101, 115, 50, 95, 118, 52, 18, 1, 0, 12, 1, 14, 1, + ]; let msg = super::decode_message::(&data).unwrap(); - // println!("{:?}", msg); + let pp = msg.unwrap_push_pull(); + let msg = PushPullMessage::::from_ref(pp).unwrap(); + println!("{:?}", msg); +} + +#[test] +fn test_3() { + let coord = coordinate::Coordinate::new(); + let data1 = coord.encode_to_vec().unwrap(); + println!("{:?}", &data1); + let (_, coord) = ::decode(&data1).unwrap(); + println!("{:?}", coord); } diff --git a/serf-core/src/types/user_event.rs b/serf-core/src/types/user_event.rs index 2edc859..6a38e02 100644 --- a/serf-core/src/types/user_event.rs +++ b/serf-core/src/types/user_event.rs @@ -1,7 +1,7 @@ use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, WireType, bytes::Bytes, - utils::{merge, skip, split}, + utils::{merge, skip}, }; use smol_str::SmolStr; @@ -84,14 +84,7 @@ impl<'a> DataRef<'a, UserEvent> for UserEventRef<'a> { payload = Some(val); offset += size; } - other => { - offset += 1; - - let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type) - .map_err(|v| DecodeError::unknown_wire_type("UserEvent", v))?; - offset += skip(wire_type, &buf[offset..])?; - } + _ => offset += skip("UserEvent", &buf[offset..])?, } } diff --git a/serf-core/src/types/user_event/message.rs b/serf-core/src/types/user_event/message.rs index 5b50b02..85b2e95 100644 --- a/serf-core/src/types/user_event/message.rs +++ 
b/serf-core/src/types/user_event/message.rs @@ -1,7 +1,7 @@ use memberlist_core::proto::{ CheapClone, Data, DataRef, DecodeError, EncodeError, WireType, bytes::Bytes, - utils::{merge, skip, split}, + utils::{merge, skip}, }; use smol_str::SmolStr; @@ -171,14 +171,7 @@ impl<'a> DataRef<'a, UserEventMessage> for UserEventMessageRef<'a> { offset += o; payload = Some(v); } - other => { - offset += 1; - - let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type) - .map_err(|v| DecodeError::unknown_wire_type("UserEventMessage", v))?; - offset += skip(wire_type, &buf[offset..])?; - } + _ => offset += skip("UserEventMessage", &buf[offset..])?, } } diff --git a/serf-core/src/types/user_event/user_events.rs b/serf-core/src/types/user_event/user_events.rs index 34f0577..c1cd05c 100644 --- a/serf-core/src/types/user_event/user_events.rs +++ b/serf-core/src/types/user_event/user_events.rs @@ -1,6 +1,6 @@ use memberlist_core::proto::{ Data, DataRef, DecodeError, EncodeError, OneOrMore, RepeatedDecoder, WireType, - utils::{merge, skip, split}, + utils::{merge, skip}, }; use super::{super::LamportTime, UserEvent}; @@ -77,31 +77,22 @@ impl<'a> DataRef<'a, UserEvents> for UserEventsRef<'a> { offset += size; } EVENTS_BYTE => { - offset += 1; - - let readed = super::skip(WireType::LengthDelimited, &buf[offset..])?; + let readed = super::skip("UserEvents", &buf[offset..])?; if let Some((ref mut fnso, ref mut lnso)) = events_offsets { if *fnso > offset { - *fnso = offset - 1; + *fnso = offset; } if *lnso < offset + readed { *lnso = offset + readed; } } else { - events_offsets = Some((offset - 1, offset + readed)); + events_offsets = Some((offset, offset + readed)); } num_events += 1; offset += readed; } - other => { - offset += 1; - - let (wire_type, _) = split(other); - let wire_type = WireType::try_from(wire_type) - .map_err(|v| DecodeError::unknown_wire_type("UserEvents", v))?; - offset += skip(wire_type, &buf[offset..])?; - } + _ => offset += 
skip("UserEvents", &buf[offset..])?, } } diff --git a/serf/test/main/net/coordinates.rs b/serf/test/main/net/coordinates.rs index a690850..15a8381 100644 --- a/serf/test/main/net/coordinates.rs +++ b/serf/test/main/net/coordinates.rs @@ -39,30 +39,30 @@ macro_rules! test_mod { >(opts, opts2, opts3)); } - #[test] - fn test_serf_coordinates_v6() { - let name = "serf_coordinates1_v6"; - let mut opts = NetTransportOptions::new(SmolStr::new(name)); - opts.add_bind_address(next_socket_addr_v6()); + // #[test] + // fn test_serf_coordinates_v6() { + // let name = "serf_coordinates1_v6"; + // let mut opts = NetTransportOptions::new(SmolStr::new(name)); + // opts.add_bind_address(next_socket_addr_v6()); - let name = "serf_coordinates2_v6"; - let mut opts2 = NetTransportOptions::new(SmolStr::new(name)); - opts2.add_bind_address(next_socket_addr_v6()); + // let name = "serf_coordinates2_v6"; + // let mut opts2 = NetTransportOptions::new(SmolStr::new(name)); + // opts2.add_bind_address(next_socket_addr_v6()); - let name = "serf_coordinates3_v6"; - let mut opts3 = NetTransportOptions::new(SmolStr::new(name)); - opts3.add_bind_address(next_socket_addr_v6()); + // let name = "serf_coordinates3_v6"; + // let mut opts3 = NetTransportOptions::new(SmolStr::new(name)); + // opts3.add_bind_address(next_socket_addr_v6()); - [< $rt:snake _run >](serf_coordinates::< - NetTransport< - SmolStr, - SocketAddrResolver<[< $rt:camel Runtime >]>, - Tcp<[< $rt:camel Runtime >]>, + // [< $rt:snake _run >](serf_coordinates::< + // NetTransport< + // SmolStr, + // SocketAddrResolver<[< $rt:camel Runtime >]>, + // Tcp<[< $rt:camel Runtime >]>, - [< $rt:camel Runtime >], - >, - >(opts, opts2, opts3)); - } + // [< $rt:camel Runtime >], + // >, + // >(opts, opts2, opts3)); + // } } } }; From ad1ada2d4942e6280feb8b71185375ada15afd4d Mon Sep 17 00:00:00 2001 From: al8n Date: Mon, 3 Mar 2025 16:08:39 +0800 Subject: [PATCH 18/39] Fix all unit tests --- .../src/serf/base/tests/serf/delegate.rs | 3 -- 
serf-core/src/serf/delegate.rs | 7 +--- serf-core/src/serf/internal_query.rs | 6 ++- serf-core/src/types/join.rs | 6 ++- serf/test/main/net/coordinates.rs | 40 +++++++++---------- serf/test/main/net/delegate/local_state.rs | 1 - 6 files changed, 31 insertions(+), 32 deletions(-) diff --git a/serf-core/src/serf/base/tests/serf/delegate.rs b/serf-core/src/serf/base/tests/serf/delegate.rs index 0058591..c8a45be 100644 --- a/serf-core/src/serf/base/tests/serf/delegate.rs +++ b/serf-core/src/serf/base/tests/serf/delegate.rs @@ -77,9 +77,6 @@ where .local_state(false) .await; - // Verify - assert_eq!(buf[0], u8::from(MessageType::PushPull), "bad message type"); - // Attempt a decode let pp = crate::types::decode_message::(&buf).unwrap(); diff --git a/serf-core/src/serf/delegate.rs b/serf-core/src/serf/delegate.rs index 15b8730..2527a51 100644 --- a/serf-core/src/serf/delegate.rs +++ b/serf-core/src/serf/delegate.rs @@ -200,7 +200,6 @@ where match msg { MessageRef::Leave(l) => { tracing::debug!("serf: leave message: {:?}", l.id()); - // TODO(al8n): do not read to owned here match as Data>::from_ref(l) { Err(e) => { tracing::error!(err=%e, "serf: failed to decode leave message"); @@ -212,8 +211,6 @@ where } MessageRef::Join(j) => { tracing::debug!("serf: join message: {:?}", j.id()); - // TODO(al8n): do not read to owned here - match as Data>::from_ref(j) { Err(e) => { tracing::error!(err=%e, "serf: failed to decode join message"); @@ -409,7 +406,7 @@ where match crate::types::encode_message_to_bytes(&pp) { Ok(buf) => { - tracing::debug!(data=?buf.as_ref(), "serf: local state"); + tracing::trace!(data=?buf.as_ref(), "serf: local state"); buf } Err(e) => { @@ -425,7 +422,7 @@ where return; } - tracing::debug!(data=?buf, "serf: merge remote state"); + tracing::trace!(data=?buf, "serf: merge remote state"); // Check the message type let msg = match crate::types::decode_message::(buf) { diff --git a/serf-core/src/serf/internal_query.rs b/serf-core/src/serf/internal_query.rs 
index b8ec23f..537c022 100644 --- a/serf-core/src/serf/internal_query.rs +++ b/serf-core/src/serf/internal_query.rs @@ -10,11 +10,13 @@ use memberlist_core::{ use crate::{ delegate::Delegate, event::{CrateEvent, InternalQueryEvent, QueryEvent}, - types::MessageRef, }; #[cfg(feature = "encryption")] -use crate::{error::Error, types::KeyResponseMessage}; +use crate::{ + error::Error, + types::{KeyResponseMessage, MessageRef}, +}; #[cfg(feature = "encryption")] use smol_str::SmolStr; diff --git a/serf-core/src/types/join.rs b/serf-core/src/types/join.rs index c596ac6..f71e055 100644 --- a/serf-core/src/types/join.rs +++ b/serf-core/src/types/join.rs @@ -18,7 +18,11 @@ const ID_TAG: u8 = 2; pub struct JoinMessage { /// The lamport time #[viewit( - getter(const, attrs(doc = "Returns the lamport time for this message")), + getter( + const, + style = "move", + attrs(doc = "Returns the lamport time for this message") + ), setter( const, attrs(doc = "Sets the lamport time for this message (Builder pattern)") diff --git a/serf/test/main/net/coordinates.rs b/serf/test/main/net/coordinates.rs index 15a8381..a690850 100644 --- a/serf/test/main/net/coordinates.rs +++ b/serf/test/main/net/coordinates.rs @@ -39,30 +39,30 @@ macro_rules! 
test_mod { >(opts, opts2, opts3)); } - // #[test] - // fn test_serf_coordinates_v6() { - // let name = "serf_coordinates1_v6"; - // let mut opts = NetTransportOptions::new(SmolStr::new(name)); - // opts.add_bind_address(next_socket_addr_v6()); + #[test] + fn test_serf_coordinates_v6() { + let name = "serf_coordinates1_v6"; + let mut opts = NetTransportOptions::new(SmolStr::new(name)); + opts.add_bind_address(next_socket_addr_v6()); - // let name = "serf_coordinates2_v6"; - // let mut opts2 = NetTransportOptions::new(SmolStr::new(name)); - // opts2.add_bind_address(next_socket_addr_v6()); + let name = "serf_coordinates2_v6"; + let mut opts2 = NetTransportOptions::new(SmolStr::new(name)); + opts2.add_bind_address(next_socket_addr_v6()); - // let name = "serf_coordinates3_v6"; - // let mut opts3 = NetTransportOptions::new(SmolStr::new(name)); - // opts3.add_bind_address(next_socket_addr_v6()); + let name = "serf_coordinates3_v6"; + let mut opts3 = NetTransportOptions::new(SmolStr::new(name)); + opts3.add_bind_address(next_socket_addr_v6()); - // [< $rt:snake _run >](serf_coordinates::< - // NetTransport< - // SmolStr, - // SocketAddrResolver<[< $rt:camel Runtime >]>, - // Tcp<[< $rt:camel Runtime >]>, + [< $rt:snake _run >](serf_coordinates::< + NetTransport< + SmolStr, + SocketAddrResolver<[< $rt:camel Runtime >]>, + Tcp<[< $rt:camel Runtime >]>, - // [< $rt:camel Runtime >], - // >, - // >(opts, opts2, opts3)); - // } + [< $rt:camel Runtime >], + >, + >(opts, opts2, opts3)); + } } } }; diff --git a/serf/test/main/net/delegate/local_state.rs b/serf/test/main/net/delegate/local_state.rs index a0455cc..1bbaf01 100644 --- a/serf/test/main/net/delegate/local_state.rs +++ b/serf/test/main/net/delegate/local_state.rs @@ -29,7 +29,6 @@ macro_rules! 
test_mod { SmolStr, SocketAddrResolver<[< $rt:camel Runtime >]>, Tcp<[< $rt:camel Runtime >]>, - [< $rt:camel Runtime >], >, >(opts, opts2)); From 9cd4cc86d9b3f985b90a1490341ea1667af6f04b Mon Sep 17 00:00:00 2001 From: al8n Date: Mon, 3 Mar 2025 17:45:00 +0800 Subject: [PATCH 19/39] Update --- Cargo.toml | 17 +++++++++++------ serf-core/src/types/tests.rs | 22 ---------------------- 2 files changed, 11 insertions(+), 28 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 252b999..3bb0af0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,9 +24,6 @@ futures = { version = "0.3", default-features = false } serde = { version = "1", features = ["derive"] } humantime-serde = "1" indexmap = "2" -# memberlist-proto = { version = "0.3", default-features = false } -# memberlist-core = { version = "0.3", default-features = false } -# memberlist = { version = "0.3", default-features = false } thiserror = { version = "2", default-features = false } viewit = "0.1.5" regex = "1" @@ -37,8 +34,16 @@ rand = "0.9" arbitrary = { version = "1", default-features = false, features = ["derive"] } quickcheck = { version = "1", default-features = false } -memberlist-proto = { version = "0.1", path = "../memberlist/memberlist-proto", default-features = false } -memberlist-core = { version = "0.6", path = "../memberlist/memberlist-core", default-features = false } -memberlist = { version = "0.6", path = "../memberlist/memberlist", default-features = false } +# memberlist-proto = { version = "0.1", path = "../memberlist/memberlist-proto", default-features = false } +# memberlist-core = { version = "0.6", path = "../memberlist/memberlist-core", default-features = false } +# memberlist = { version = "0.6", path = "../memberlist/memberlist", default-features = false } + +memberlist-proto = { version = "0.1", default-features = false, git = "https://github.com/al8n/memberlist" } +memberlist-core = { version = "0.6", default-features = false, git = "https://github.com/al8n/memberlist" } 
+memberlist = { version = "0.6", default-features = false, git = "https://github.com/al8n/memberlist" } + +# memberlist-proto = { version = "0.1", default-features = false } +# memberlist-core = { version = "0.6", default-features = false } +# memberlist = { version = "0.6", default-features = false } serf-core = { path = "serf-core", version = "0.3.0", default-features = false } diff --git a/serf-core/src/types/tests.rs b/serf-core/src/types/tests.rs index 6f70106..25253fe 100644 --- a/serf-core/src/types/tests.rs +++ b/serf-core/src/types/tests.rs @@ -470,25 +470,3 @@ encodable_round_trip!( , , ); - -#[test] -fn test() { - let data = [ - 19, 33, 9, 1, 18, 25, 17, 20, 115, 101, 114, 102, 95, 99, 111, 111, 114, 100, 105, 110, 97, - 116, 101, 115, 50, 95, 118, 52, 18, 1, 0, 12, 1, 14, 1, - ]; - - let msg = super::decode_message::(&data).unwrap(); - let pp = msg.unwrap_push_pull(); - let msg = PushPullMessage::::from_ref(pp).unwrap(); - println!("{:?}", msg); -} - -#[test] -fn test_3() { - let coord = coordinate::Coordinate::new(); - let data1 = coord.encode_to_vec().unwrap(); - println!("{:?}", &data1); - let (_, coord) = ::decode(&data1).unwrap(); - println!("{:?}", coord); -} From 061f642ebde0aa7b28993b0249e03b9b7407db7f Mon Sep 17 00:00:00 2001 From: al8n Date: Mon, 3 Mar 2025 17:46:45 +0800 Subject: [PATCH 20/39] Update write_keyring_file.rs --- serf/test/main/net/write_keyring_file.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/serf/test/main/net/write_keyring_file.rs b/serf/test/main/net/write_keyring_file.rs index dea2559..f397474 100644 --- a/serf/test/main/net/write_keyring_file.rs +++ b/serf/test/main/net/write_keyring_file.rs @@ -29,7 +29,7 @@ macro_rules! 
test_mod { [< $rt:camel Runtime >], >, >(|kr| { - (opts, MemberlistOptions::lan().with_primary_key(Some(kr)).with_gossip_verify_outgoing(true).with_encryption_algo(Some(EncryptionAlgorithm::default()))) + (opts, MemberlistOptions::lan().with_primary_key(kr).with_gossip_verify_outgoing(true).with_encryption_algo(EncryptionAlgorithm::default())) })); } @@ -48,7 +48,7 @@ macro_rules! test_mod { [< $rt:camel Runtime >], >, >(|kr| { - (opts, MemberlistOptions::lan().with_primary_key(Some(kr)).with_gossip_verify_outgoing(true).with_encryption_algo(Some(EncryptionAlgorithm::default()))) + (opts, MemberlistOptions::lan().with_primary_key(kr).with_gossip_verify_outgoing(true).with_encryption_algo(EncryptionAlgorithm::default())) })); } } From 871435b8fd8e8b77125c5f7ef478ad7962fd81a1 Mon Sep 17 00:00:00 2001 From: al8n Date: Mon, 3 Mar 2025 17:49:13 +0800 Subject: [PATCH 21/39] Update CI --- .github/workflows/ci.yml | 7 +------ .github/workflows/coverage.yml | 3 --- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8c3f7cb..9816d30 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -68,12 +68,7 @@ jobs: - name: Run Unit Tests for core run: | cargo test --no-default-features --features "test,encryption,serde" - working-directory: core - - - name: Run Unit Tests for types - run: | - cargo test --no-default-features --features "metrics,encryption,serde" - working-directory: types + working-directory: serf-core - name: Cache Cargo registry uses: actions/cache@v4 diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 78662dc..865581a 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -34,9 +34,6 @@ jobs: - crate: core features: "test,metrics" name: "serf-core" - - crate: types - features: "metrics,encryption" - name: "serf-proto" - crate: serf features: "test,tokio,tcp,encryption,metrics" name: "serf-tcp-encryption" From 
dae8c832173cc69ba9dbdabbccea524b4dbf9aa6 Mon Sep 17 00:00:00 2001 From: al8n Date: Mon, 3 Mar 2025 17:57:15 +0800 Subject: [PATCH 22/39] Update CI --- .github/workflows/ci.yml | 2 +- .github/workflows/coverage.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9816d30..d36f4a2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -67,7 +67,7 @@ jobs: - name: Run Unit Tests for core run: | - cargo test --no-default-features --features "test,encryption,serde" + cargo test --no-default-features --features "test,encryption,serde,quickcheck" working-directory: serf-core - name: Cache Cargo registry diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 865581a..b9bd549 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -32,7 +32,7 @@ jobs: matrix: include: - crate: core - features: "test,metrics" + features: "test,metrics,quickcheck,encryption" name: "serf-core" - crate: serf features: "test,tokio,tcp,encryption,metrics" From 3243dd96a23473a357171d084efb037a3be6d3a8 Mon Sep 17 00:00:00 2001 From: al8n Date: Mon, 3 Mar 2025 20:19:09 +0800 Subject: [PATCH 23/39] Fix the write_keyring_file --- serf-core/Cargo.toml | 6 ++---- serf-core/src/key_manager.rs | 30 +++++++++++++-------------- serf-core/src/serf/base.rs | 4 +--- serf-core/src/serf/base/tests/serf.rs | 8 ++----- serf-core/src/serf/delegate.rs | 2 +- serf-core/src/serf/internal_query.rs | 2 +- 6 files changed, 22 insertions(+), 30 deletions(-) diff --git a/serf-core/Cargo.toml b/serf-core/Cargo.toml index a56c69f..4f05d92 100644 --- a/serf-core/Cargo.toml +++ b/serf-core/Cargo.toml @@ -15,7 +15,7 @@ categories.workspace = true default = ["metrics"] metrics = ["memberlist-core/metrics", "dep:metrics"] -encryption = ["memberlist-core/encryption", "base64", "serde"] +encryption = ["memberlist-core/encryption", "serde", "serde_json"] crc32 = 
["memberlist-core/crc32"] murmur3 = ["memberlist-core/murmur3"] @@ -73,9 +73,7 @@ metrics = { version = "0.24", optional = true } serde = { workspace = true, optional = true } humantime-serde = { workspace = true, optional = true } -serde_json = "1" - -base64 = { version = "0.22", optional = true } +serde_json = { version = "1", optional = true } arbitrary = { workspace = true, optional = true, default-features = false, features = ["derive"] } quickcheck = { workspace = true, optional = true, default-features = false } diff --git a/serf-core/src/key_manager.rs b/serf-core/src/key_manager.rs index 51683f6..c10070b 100644 --- a/serf-core/src/key_manager.rs +++ b/serf-core/src/key_manager.rs @@ -17,7 +17,7 @@ use super::{ delegate::Delegate, error::Error, serf::{NodeResponse, QueryResponse}, - types::{KeyRequestMessage, MessageType}, + types::KeyRequestMessage, }; /// KeyResponse is used to relay a query for a list of all keys in use. @@ -232,14 +232,10 @@ where resp.num_resp += 1; // Decode the response - if r.payload.is_empty() || r.payload[0] != u8::from(MessageType::KeyResponse) { - resp.messages.insert( - r.from.id().cheap_clone(), - SmolStr::new(format!( - "Invalid key query response type: {:?}", - r.payload.as_ref() - )), - ); + if r.payload.is_empty() { + resp + .messages + .insert(r.from.id().cheap_clone(), SmolStr::new("empty payload")); resp.num_err += 1; if resp.num_resp == resp.num_nodes { @@ -253,9 +249,11 @@ where Ok(msg) => match msg { MessageRef::KeyResponse(kr) => kr, msg => { + tracing::error!(type=%msg.ty(), "serf: invalid key query response type"); + resp.messages.insert( r.from.id().cheap_clone(), - format_smolstr!("Invalid key query response type: {}", msg.ty()), + format_smolstr!("invalid key query response: {}", msg.ty()), ); resp.num_err += 1; @@ -266,10 +264,10 @@ where } }, Err(e) => { - resp.messages.insert( - r.from.id().cheap_clone(), - SmolStr::new(format!("Failed to decode key query response: {:?}", e)), - ); + tracing::error!(err=%e, 
"serf: failed to decode key query response"); + resp + .messages + .insert(r.from.id().cheap_clone(), format_smolstr!("{e}")); resp.num_err += 1; if resp.num_resp == resp.num_nodes { @@ -285,7 +283,9 @@ where SmolStr::new(node_response.message()), ); resp.num_err += 1; - } else if node_response.result() && node_response.message().is_empty() { + } + + if node_response.result() && !node_response.message().is_empty() { tracing::warn!("serf: {}", node_response.message()); resp.messages.insert( r.from.id().cheap_clone(), diff --git a/serf-core/src/serf/base.rs b/serf-core/src/serf/base.rs index fc910e7..2965bb4 100644 --- a/serf-core/src/serf/base.rs +++ b/serf-core/src/serf/base.rs @@ -399,8 +399,6 @@ where /// Serialize the current keyring and save it to a file. #[cfg(feature = "encryption")] pub(crate) async fn write_keyring_file(&self) -> std::io::Result<()> { - use base64::{Engine as _, engine::general_purpose}; - let Some(path) = self.inner.opts.keyring_file() else { return Ok(()); }; @@ -408,7 +406,7 @@ where if let Some(keyring) = self.inner.memberlist.keyring() { let encoded_keys = keyring .keys() - .map(|k| general_purpose::STANDARD.encode(k)) + .map(|k| k.to_base64()) .collect::>(); #[cfg(unix)] diff --git a/serf-core/src/serf/base/tests/serf.rs b/serf-core/src/serf/base/tests/serf.rs index 7039b20..209f6bc 100644 --- a/serf-core/src/serf/base/tests/serf.rs +++ b/serf-core/src/serf/base/tests/serf.rs @@ -790,8 +790,6 @@ pub async fn serf_write_keyring_file( { use std::io::Read; - use base64::{Engine as _, engine::general_purpose}; - const EXISTING: &str = "T9jncgl9mbLus+baTTa7q7nPSUrXwbDi2dhbtqir37s="; const NEW_KEY: &str = "HvY8ubRZMgafUOWvrOadwOckVa1wN3QWAo46FVKbVN8="; @@ -799,8 +797,7 @@ pub async fn serf_write_keyring_file( let mut p = td.path().join("serf_write_keying_file"); p.set_extension("json"); - let existing_bytes = general_purpose::STANDARD.decode(EXISTING).unwrap(); - let sk = 
memberlist_core::proto::SecretKey::try_from(existing_bytes.as_slice()).unwrap(); + let sk = crate::types::SecretKey::try_from(EXISTING).unwrap(); let (topts, mopts) = get_transport_opts(sk); let serf = Serf::::new( @@ -817,8 +814,7 @@ pub async fn serf_write_keyring_file( ); let manager = serf.key_manager(); - let new_key = general_purpose::STANDARD.decode(NEW_KEY).unwrap(); - let new_sk = memberlist_core::proto::SecretKey::try_from(new_key.as_slice()).unwrap(); + let new_sk = crate::types::SecretKey::try_from(NEW_KEY).unwrap(); manager.install_key(new_sk, None).await.unwrap(); let mut keyring_file = std::fs::File::open(&p).unwrap(); diff --git a/serf-core/src/serf/delegate.rs b/serf-core/src/serf/delegate.rs index 2527a51..151f1af 100644 --- a/serf-core/src/serf/delegate.rs +++ b/serf-core/src/serf/delegate.rs @@ -252,7 +252,7 @@ where } } MessageRef::QueryResponse(qr) => { - tracing::debug!("serf: query response message: {:?}", qr.from()); + tracing::debug!("serf: query response message"); if let Err(e) = this.handle_query_response(qr).await { tracing::warn!(err=%e, "serf: failed to decode query response message"); } diff --git a/serf-core/src/serf/internal_query.rs b/serf-core/src/serf/internal_query.rs index 537c022..d55711b 100644 --- a/serf-core/src/serf/internal_query.rs +++ b/serf-core/src/serf/internal_query.rs @@ -458,7 +458,7 @@ where #[cfg(feature = "encryption")] async fn send_key_response(q: &QueryEvent, resp: &mut KeyResponseMessage) { match q.name.as_str() { - "_serf_list_keys" => { + crate::event::INTERNAL_LIST_KEYS => { let (raw, qresp) = match Self::key_list_response_with_correct_size(q, resp) { Ok((raw, qresp)) => (raw, qresp), Err(e) => { From 6fcc725ec6e76611da4866e4c78480a40b0d8713 Mon Sep 17 00:00:00 2001 From: al8n Date: Mon, 3 Mar 2025 20:26:41 +0800 Subject: [PATCH 24/39] Update net.yml --- .github/workflows/net.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/net.yml b/.github/workflows/net.yml 
index 821eae0..c3ade14 100644 --- a/.github/workflows/net.yml +++ b/.github/workflows/net.yml @@ -36,7 +36,6 @@ jobs: runtime: [tokio, async-std, smol] stream_layer: - tls - - native-tls - tcp steps: - uses: actions/checkout@v4 @@ -67,7 +66,7 @@ jobs: - name: Run Unit Tests for serf based on net transport run: | - cargo test --no-default-features --features "test,encryption,net,metrics,${{ matrix.runtime }}" -- --test-threads=1 + cargo test --no-default-features --features "test,encryption,net,metrics,${{ matrix.runtime }},${{ matrix.stream_layer }}" -- --test-threads=1 working-directory: serf - name: Cache Cargo registry From 66b88f21680d3c4e8985eadef6a60a1011a40c1b Mon Sep 17 00:00:00 2001 From: al8n Date: Mon, 3 Mar 2025 20:55:43 +0800 Subject: [PATCH 25/39] Update CI --- .github/workflows/coverage.yml | 4 ++-- serf-core/Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index b9bd549..c90cc4b 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -32,10 +32,10 @@ jobs: matrix: include: - crate: core - features: "test,metrics,quickcheck,encryption" + features: "test,metrics,quickcheck,encryption,tracing" name: "serf-core" - crate: serf - features: "test,tokio,tcp,encryption,metrics" + features: "test,tokio,tcp,encryption,metrics,tracing" name: "serf-tcp-encryption" steps: - uses: actions/checkout@v4 diff --git a/serf-core/Cargo.toml b/serf-core/Cargo.toml index 4f05d92..c44b433 100644 --- a/serf-core/Cargo.toml +++ b/serf-core/Cargo.toml @@ -91,7 +91,7 @@ agnostic-lite = { version = "0.5", features = ["tokio"] } tokio = { version = "1", features = ["full"] } futures = { workspace = true, features = ["executor"] } tempfile = "3" -memberlist-core = { workspace = true, features = ["quickcheck", "arbitrary"] } +memberlist-core = { workspace = true, features = ["quickcheck", "arbitrary", "test"] } quickcheck_macros = "1" quickcheck.workspace = 
true paste = "1" From f4dd7c9c7fa8b1b4995c0521982c616e7c31ff78 Mon Sep 17 00:00:00 2001 From: al8n Date: Mon, 3 Mar 2025 21:16:57 +0800 Subject: [PATCH 26/39] Update CI --- .github/workflows/coverage.yml | 6 +-- .github/workflows/quinn.yml.bk | 76 ---------------------------------- .github/workflows/s2n.yml.bk | 76 ---------------------------------- serf-core/Cargo.toml | 1 - 4 files changed, 3 insertions(+), 156 deletions(-) delete mode 100644 .github/workflows/quinn.yml.bk delete mode 100644 .github/workflows/s2n.yml.bk diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index c90cc4b..36d38ef 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -31,11 +31,11 @@ jobs: strategy: matrix: include: - - crate: core - features: "test,metrics,quickcheck,encryption,tracing" + - crate: serf-core + features: "test,metrics,quickcheck,encryption" name: "serf-core" - crate: serf - features: "test,tokio,tcp,encryption,metrics,tracing" + features: "test,tokio,tcp,encryption,metrics" name: "serf-tcp-encryption" steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/quinn.yml.bk b/.github/workflows/quinn.yml.bk deleted file mode 100644 index 7efc750..0000000 --- a/.github/workflows/quinn.yml.bk +++ /dev/null @@ -1,76 +0,0 @@ -name: quinn - -on: - push: - branches: - - main - paths-ignore: - - 'README.md' - - 'COPYRIGHT' - - 'LICENSE*' - - '**.md' - - '**.txt' - - 'art' - pull_request: - paths-ignore: - - 'README.md' - - 'COPYRIGHT' - - 'LICENSE*' - - '**.md' - - '**.txt' - - 'art' - workflow_dispatch: - schedule: [cron: "40 1 * * *"] - -jobs: - test: - name: ${{ matrix.os }} - ${{ matrix.runtime }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: - - ubuntu-latest - # - macos-latest, - # - windows-latest - runtime: [tokio, async-std, smol] - steps: - - uses: actions/checkout@v4 - - - name: Install Rust - run: | - rustup update stable && rustup default stable - rustup component add clippy 
- rustup component add rustfmt - - - name: Install OpenSSL (Windows) - if: matrix.os == 'windows-latest' - shell: powershell - run: | - echo "VCPKG_ROOT=$env:VCPKG_INSTALLATION_ROOT" | Out-File -FilePath $env:GITHUB_ENV -Append - vcpkg install openssl:x64-windows-static-md - - - name: Setup loopback interface (Windows) - if: matrix.os == 'windows-latest' - shell: powershell - run: ci\setup_subnet_windows.ps1 - - name: Setup loopback interface (MacOS) - if: matrix.os == 'macos-latest' - run: ci/setup_subnet_macos.sh - - name: Setup loopback interface (Ubuntu) - if: matrix.os == 'ubuntu-latest' - run: ci/setup_subnet_ubuntu.sh - - - name: Run Unit Tests for serf based on quinn transport - run: | - cargo test --no-default-features --features "test,encryption,quinn,metrics,${{ matrix.runtime }}" -- --test-threads=1 - working-directory: serf - - - name: Cache Cargo registry - uses: actions/cache@v4 - with: - path: | - ~/.cargo/registry - ~/.cargo/git - target - key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}-${{ matrix.runtime }}-quinn diff --git a/.github/workflows/s2n.yml.bk b/.github/workflows/s2n.yml.bk deleted file mode 100644 index 978761d..0000000 --- a/.github/workflows/s2n.yml.bk +++ /dev/null @@ -1,76 +0,0 @@ -name: s2n - -on: - push: - branches: - - main - paths-ignore: - - 'README.md' - - 'COPYRIGHT' - - 'LICENSE*' - - '**.md' - - '**.txt' - - 'art' - pull_request: - paths-ignore: - - 'README.md' - - 'COPYRIGHT' - - 'LICENSE*' - - '**.md' - - '**.txt' - - 'art' - workflow_dispatch: - schedule: [cron: "40 1 * * *"] - -jobs: - test: - name: ${{ matrix.os }} - ${{ matrix.runtime }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: - - ubuntu-latest - # - macos-latest, - # - windows-latest - runtime: [tokio] - steps: - - uses: actions/checkout@v4 - - - name: Install Rust - run: | - rustup update stable && rustup default stable - rustup component add clippy - rustup component add rustfmt - - - name: Install OpenSSL (Windows) 
- if: matrix.os == 'windows-latest' - shell: powershell - run: | - echo "VCPKG_ROOT=$env:VCPKG_INSTALLATION_ROOT" | Out-File -FilePath $env:GITHUB_ENV -Append - vcpkg install openssl:x64-windows-static-md - - - name: Setup loopback interface (Windows) - if: matrix.os == 'windows-latest' - shell: powershell - run: ci\setup_subnet_windows.ps1 - - name: Setup loopback interface (MacOS) - if: matrix.os == 'macos-latest' - run: ci/setup_subnet_macos.sh - - name: Setup loopback interface (Ubuntu) - if: matrix.os == 'ubuntu-latest' - run: ci/setup_subnet_ubuntu.sh - - - name: Run Unit Tests for serf based on s2n transport - run: | - cargo test --no-default-features --features "test,encryption,s2n,metrics,${{ matrix.runtime }}" -- --test-threads=1 - working-directory: serf - - - name: Cache Cargo registry - uses: actions/cache@v4 - with: - path: | - ~/.cargo/registry - ~/.cargo/git - target - key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}-${{ matrix.runtime }}-s2n diff --git a/serf-core/Cargo.toml b/serf-core/Cargo.toml index c44b433..7cc27f0 100644 --- a/serf-core/Cargo.toml +++ b/serf-core/Cargo.toml @@ -43,7 +43,6 @@ test = ["memberlist-core/test", "paste", "tracing-subscriber", "tempfile"] arbitrary = ["dep:arbitrary", "memberlist-core/arbitrary", "smol_str/arbitrary"] quickcheck = ["dep:quickcheck", "memberlist-core/quickcheck"] - [dependencies] auto_impl = "1" atomic_refcell = "0.1" From 2d3f18f483fd944d114a9d05dd32a972e6dcede2 Mon Sep 17 00:00:00 2001 From: al8n Date: Mon, 3 Mar 2025 21:26:12 +0800 Subject: [PATCH 27/39] Update Cargo.toml --- Cargo.toml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3bb0af0..082b541 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,12 +38,12 @@ quickcheck = { version = "1", default-features = false } # memberlist-core = { version = "0.6", path = "../memberlist/memberlist-core", default-features = false } # memberlist = { version = "0.6", path = 
"../memberlist/memberlist", default-features = false } -memberlist-proto = { version = "0.1", default-features = false, git = "https://github.com/al8n/memberlist" } -memberlist-core = { version = "0.6", default-features = false, git = "https://github.com/al8n/memberlist" } -memberlist = { version = "0.6", default-features = false, git = "https://github.com/al8n/memberlist" } +# memberlist-proto = { version = "0.1", default-features = false, git = "https://github.com/al8n/memberlist" } +# memberlist-core = { version = "0.6", default-features = false, git = "https://github.com/al8n/memberlist" } +# memberlist = { version = "0.6", default-features = false, git = "https://github.com/al8n/memberlist" } -# memberlist-proto = { version = "0.1", default-features = false } -# memberlist-core = { version = "0.6", default-features = false } -# memberlist = { version = "0.6", default-features = false } +memberlist-proto = { version = "0.1", default-features = false } +memberlist-core = { version = "0.6", default-features = false } +memberlist = { version = "0.6", default-features = false } serf-core = { path = "serf-core", version = "0.3.0", default-features = false } From d5fd35c74d29ea3b6210e706d7b3ff6a39118c37 Mon Sep 17 00:00:00 2001 From: al8n Date: Mon, 3 Mar 2025 21:30:22 +0800 Subject: [PATCH 28/39] Update README.md --- README.md | 103 +++++++++++++++++++++++++++++++++++------------------- 1 file changed, 67 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 57552f7..ac91e70 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,70 @@ serf is eventually consistent but converges quickly on average. The speed at whi serf is WASM/WASI friendly, all crates can be compiled to `wasm-wasi` and `wasm-unknown-unknown` (need to configure the crate features). -### Design +## Installation + +- By using `TCP/UDP`, `TLS/UDP` transport + + ```toml + serf = { version = "0.3", features = [ + "tcp", + # Enable a checksum, as UDP is not reliable. 
+ # Built in supports are: "crc32", "xxhash64", "xxhash32", "xxhash3", "murmur3" + "crc32", + # Enable a compression, this is optional, + # and possible values are `snappy`, `brotli`, `zstd` and `lz4`. + # You can enable all. + "snappy", + # Enable encryption, this is optional, + "encryption", + # Enable a async runtime + # Builtin supports are `tokio`, `smol`, `async-std` + "tokio", + # Enable one tls implementation. This is optional. + # Users can just use encryption feature with plain TCP. + # + # "tls", + ] } + ``` + +- By using `QUIC/QUIC` transport + + For `QUIC/QUIC` transport, as QUIC is secure and reliable, so enable checksum or encryption makes no sense. + + ```toml + serf = { version = "0.3", features = [ + # Enable a compression, this is optional, + # and possible values are `snappy`, `brotli`, `zstd` and `lz4`. + # You can enable all. + "snappy", + # Enable a async runtime + # Builtin supports are `tokio`, `smol`, `async-std` + "tokio", + # Enable one of the QUIC implementation + # Builtin support is `quinn` + "quinn", + ] } + ``` + +## Protocol + +serf is based on ["SWIM: Scalable Weakly-consistent Infection-style Process Group Membership Protocol"](http://ieeexplore.ieee.org/document/1028914/). However, Hashicorp developers extends the protocol in a number of ways: + +Several extensions are made to increase propagation speed and convergence rate. +Another set of extensions, that Hashicorp developers call Lifeguard, are made to make serf more robust in the presence of slow message processing (due to factors such as CPU starvation, and network delay or loss). +For details on all of these extensions, please read Hashicorp's paper ["Lifeguard : SWIM-ing with Situational Awareness"](https://arxiv.org/abs/1707.00788), along with the serf source. + +## Q & A + +- ***Does Rust's serf implemenetation compatible to Go's serf?*** + + No! Rust implementation use protobuf-like encoding/decoding whereas Go's implementation uses message pack. 
+ +- ***If Go's serf adds more functionalities, will this project also support?*** + + Yes! And this project may also add more functionalities whereas the Go's serf does not have. e.g. wasmer support, bindings to other languages and etc. + +## Design Unlike the original Go implementation, Rust's serf use highly generic and layered architecture, users can easily implement a component by themselves and plug it to the serf. Users can even custom their own `Id` and `Address`. @@ -57,19 +120,17 @@ Here are the layers: - **[`NetTransport`](https://docs.rs/serf-net/struct.NetTransport.html)** - Three kinds of different builtin stream layers for `NetTransport`: + Builtin stream layers for `NetTransport`: - [`Tcp`](https://docs.rs/serf-net/stream_layer/tcp/struct.Tcp.html): based on TCP and UDP - [`Tls`](https://docs.rs/serf-net/stream_layer/tls/struct.Tls.html): based on [`rustls`](https://docs.rs/rustls) and UDP - - [`NativeTls`](https://docs.rs/serf-net/stream_layer/tls/struct.NativeTls.html): based on [`native-tls`](https://docs.rs/native-tls) and UDP - **[`QuicTransport`](https://docs.rs/serf-quic/struct.QuicTransport.html)** QUIC transport is an experimental transport implementation, it is well tested but still experimental. - Two kinds of different builtin stream layers for `QuicTransport`: + Builtin stream layers for `QuicTransport`: - [`Quinn`](https://docs.rs/serf-quic/stream_layer/quinn/struct.Quinn.html): based on [`quinn`](https://docs.rs/quinn) - - [`S2n`](https://docs.rs/serf-quic/stream_layer/s2n/struct.S2n.html): based on [`s2n-quic`](https://docs.rs/s2n-quic) Users can still implement their own stream layer for different kinds of transport implementations. @@ -87,10 +148,6 @@ Here are the layers: Used to involve a client in a potential cluster merge operation. Namely, when a node does a promised push/pull (as part of a join), the delegate is involved and allowed to cancel the join based on custom logic. 
The merge delegate is NOT invoked as part of the push-pull anti-entropy. - - **``** - - A delegate for encoding and decoding. Used to control how `serf` should encode/decode messages. - - **`ReconnectDelegate`** Used to custom reconnect behavior, users can implement to allow overriding the reconnect timeout for individual members. @@ -99,36 +156,10 @@ Here are the layers: CompositeDelegate is a helpful struct to split the `Delegate` into multiple small delegates, so that users do not need to implement full `Delegate` when they only want to custom some methods in the Delegate. -### Protocol - -serf is based on ["SWIM: Scalable Weakly-consistent Infection-style Process Group Membership Protocol"](http://ieeexplore.ieee.org/document/1028914/). However, Hashicorp developers extends the protocol in a number of ways: - -Several extensions are made to increase propagation speed and convergence rate. -Another set of extensions, that Hashicorp developers call Lifeguard, are made to make serf more robust in the presence of slow message processing (due to factors such as CPU starvation, and network delay or loss). -For details on all of these extensions, please read Hashicorp's paper ["Lifeguard : SWIM-ing with Situational Awareness"](https://arxiv.org/abs/1707.00788), along with the serf source. - -## Installation - -```toml -[dependencies] -serf = "0.2" -``` - -## Q & A - -- ***Does Rust's serf implemenetation compatible to Go's serf?*** - - No but yes! By default, it is not compatible. But the secret is the serialize/deserilize layer, Go's serf use the msgpack as the serialization/deserialization framework, so in theory, if you can implement a [``](https://docs.rs/serf-core/transport/trait..html) trait which compat to Go's serf, then it becomes compatible. - -- ***If Go's serf adds more functionalities, will this project also support?*** - - Yes! And this project may also add more functionalities whereas the Go's serf does not have. e.g. 
wasmer support, bindings to other languages and etc. - ## Related Projects - [`agnostic`](https://github.com/al8n/agnostic): helps you to develop runtime agnostic crates - [`nodecraft`](https://github.com/al8n/nodecraft): crafting seamless node operations for distributed systems, which provides foundational traits for node identification and address resolution. -- [`transformable`](https://github.com/al8n/transformable): transform its representation between structured and byte form. - [`peekable`](https://github.com/al8n/peekable): peekable reader and async reader - [`memberlist`](https://github.com/al8n/memberlist): A highly customable, adaptable, runtime agnostic and WASM/WASI friendly Gossip protocol which helps manage cluster membership and member failure detection. @@ -138,7 +169,7 @@ serf = "0.2" See [LICENSE](./LICENSE) for details. -Copyright (c) 2024 Al Liu. +Copyright (c) 2025 Al Liu. Copyright (c) 2013 HashiCorp, Inc. From bea0ef606734b5af18d2771276113a0932ea0133 Mon Sep 17 00:00:00 2001 From: al8n Date: Tue, 4 Mar 2025 13:07:02 +0800 Subject: [PATCH 29/39] Add example --- Cargo.toml | 2 + README.md | 5 + examples/toy-consul/Cargo.toml | 16 + examples/toy-consul/README.md | 68 +++++ examples/toy-consul/src/main.rs | 518 ++++++++++++++++++++++++++++++++ serf-core/src/types.rs | 6 +- 6 files changed, 613 insertions(+), 2 deletions(-) create mode 100644 examples/toy-consul/Cargo.toml create mode 100644 examples/toy-consul/README.md create mode 100644 examples/toy-consul/src/main.rs diff --git a/Cargo.toml b/Cargo.toml index 082b541..0e3c6ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "serf", "serf-core", + "examples/toy-consul", ] resolver = "3" @@ -47,3 +48,4 @@ memberlist-core = { version = "0.6", default-features = false } memberlist = { version = "0.6", default-features = false } serf-core = { path = "serf-core", version = "0.3.0", default-features = false } +serf = { path = "serf", version = "0.3.0", default-features = false } 
diff --git a/README.md b/README.md index ac91e70..2e7958c 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,10 @@ serf is WASM/WASI friendly, all crates can be compiled to `wasm-wasi` and `wasm- ] } ``` +## Examples + +See [examples/toyconsul](./examples/toy-consul/). + ## Protocol serf is based on ["SWIM: Scalable Weakly-consistent Infection-style Process Group Membership Protocol"](http://ieeexplore.ieee.org/document/1028914/). However, Hashicorp developers extends the protocol in a number of ways: @@ -159,6 +163,7 @@ Here are the layers: ## Related Projects - [`agnostic`](https://github.com/al8n/agnostic): helps you to develop runtime agnostic crates +- [`getifs`](https://github.com/al8n/getifs): A bunch of cross platform network tools for fetching interfaces, multicast addresses, local ip addresses, private ip addresses, public ip addresses and etc. - [`nodecraft`](https://github.com/al8n/nodecraft): crafting seamless node operations for distributed systems, which provides foundational traits for node identification and address resolution. - [`peekable`](https://github.com/al8n/peekable): peekable reader and async reader - [`memberlist`](https://github.com/al8n/memberlist): A highly customable, adaptable, runtime agnostic and WASM/WASI friendly Gossip protocol which helps manage cluster membership and member failure detection. 
diff --git a/examples/toy-consul/Cargo.toml b/examples/toy-consul/Cargo.toml new file mode 100644 index 0000000..a12665a --- /dev/null +++ b/examples/toy-consul/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "toy-consul" +rust-version = "1.85.0" +edition = "2024" +publish = false + +[dependencies] +bincode = "1" +clap = { version = "4", features = ["derive"] } +crossbeam-skiplist = "0.1" +serf = { workspace = true, features = ["default", "tokio", "tcp", "serde"] } +serde = { version = "1", features = ["derive"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +tokio = { version = "1", features = ["full"] } +scopeguard = "1" diff --git a/examples/toy-consul/README.md b/examples/toy-consul/README.md new file mode 100644 index 0000000..b62a222 --- /dev/null +++ b/examples/toy-consul/README.md @@ -0,0 +1,68 @@ +# ToyConsul + +A toy eventually consensus distributed database. + +## Installation + +```bash +cargo install --path . +``` + +## Run + +- In the first terminal + + ```bash + toyconsul start --id instance1 --addr 127.0.0.1:7001 --meta instance1 --rpc-addr toyconsul.instance1.sock + ``` + +- In the second terminal + + - Start instance 2 + + ```bash + toyconsul start --id instance2 --addr 127.0.0.1:7002 --meta instance2 --rpc-addr toyconsul.instance2.sock + ``` + + - Send the join command to instance2 and let it join to instance1 + + ```bash + toyconsul join --id instance1 --addr 127.0.0.1:7001 --rpc-addr toyconsul.instance2.sock + ``` + +- In the third terminal + + - Start instance 3 + + ```bash + toyconsul start --id instance3 --addr 127.0.0.1:7003 --meta instance3 --rpc-addr toyconsul.instance3.sock + ``` + + - Send the join command to instance3 and let it join to instance1 (can also join to instance 2) + + ```bash + toyconsul join --id instance1 --addr 127.0.0.1:7001 --rpc-addr toyconsul.instance3.sock + ``` + +- In the fourth terminal + + - Insert a key - value to the instance1 + + ```bash + toyconsul register --name web 
--addr 192.0.0.1:8080 --rpc-addr toyconsul.instance1.sock + toyconsul register --name db --addr 192.0.0.1:8081 --rpc-addr toyconsul.instance2.sock + ``` + + - After some seconds, you can get the value from any one of three instances + + ```bash + toyconsul list --rpc-addr toyconsul.instance1.sock + ``` + + ```bash + toyconsul list --rpc-addr toyconsul.instance2.sock + ``` + + ```bash + toyconsul list --rpc-addr toyconsul.instance3.sock + ``` diff --git a/examples/toy-consul/src/main.rs b/examples/toy-consul/src/main.rs new file mode 100644 index 0000000..1774fa0 --- /dev/null +++ b/examples/toy-consul/src/main.rs @@ -0,0 +1,518 @@ +use std::{net::SocketAddr, sync::Arc}; + +use bincode::{deserialize, serialize}; +use clap::Parser; +use crossbeam_skiplist::SkipMap; +use serf::{ + MemberlistOptions, Options, + agnostic::tokio::TokioRuntime, + delegate::CompositeDelegate, + net::{ + NetTransportOptions, Node, NodeId, resolver::socket_addr::SocketAddrResolver, + stream_layer::tcp::Tcp, + }, + tokio::TokioTcpSerf, + types::{MaybeResolvedAddress, SmolStr}, +}; + +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{UnixListener, UnixStream}, + sync::{ + mpsc::{UnboundedReceiver, UnboundedSender}, + oneshot, + }, +}; + +type Result = std::result::Result>; + +type ConsulDelegate = CompositeDelegate; + +#[derive(Clone, serde::Serialize, serde::Deserialize)] +struct Service { + name: SmolStr, + addr: SocketAddr, +} + +struct Inner { + serf: TokioTcpSerf, ConsulDelegate>, + services: SkipMap, + tx: UnboundedSender, +} + +#[derive(Clone)] +struct ToyConsul { + inner: Arc, +} + +impl ToyConsul { + async fn new( + opts: Options, + net_opts: NetTransportOptions, Tcp>, + ) -> Result { + let serf = TokioTcpSerf::new(net_opts, opts.with_event_buffer_size(256)).await?; + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + + let this = Self { + inner: Inner { + serf, + services: SkipMap::new(), + tx, + } + .into(), + }; + + this.clone().handle_events(rx); + + Ok(this) + } + + 
fn handle_events(self, mut rx: UnboundedReceiver) { + tokio::spawn(async move { + loop { + tokio::select! { + _ = tokio::signal::ctrl_c() => { + tracing::info!("toyconsul: shutting down event listener"); + } + ev = rx.recv() => { + if let Some(ev) = ev { + match ev { + Event::Join { id, addr, tx } => { + let res = self.inner.serf.join(Node::new(id, MaybeResolvedAddress::Resolved(addr)), false).await; + let _ = tx.send(res.map_err(Into::into).map(|_| ())); + } + Event::Register { name, addr, tx } => { + self.inner.services.insert( + name.clone(), + Service { + name, + addr, + }); + let _ = tx.send(Ok(())); + } + Event::List { tx } => { + let services = self.inner.services.iter().map(|ent| ent.value().clone()).collect(); + let _ = tx.send(Ok(services)); + } + } + } + } + } + } + }); + } + + async fn handle_register( + &self, + name: SmolStr, + addr: SocketAddr, + stream: &mut W, + ) -> Result<()> { + let (tx, rx) = oneshot::channel(); + if let Err(e) = self.inner.tx.send(Event::Register { name, addr, tx }) { + tracing::error!(err=%e, "toyconsul: fail to send get event"); + return Ok(()); + } + + let resp = rx.await?; + tracing::info!(value=?resp, "toyconsul: fetch key"); + match bincode::serialize(&resp.map_err(|e| e.to_string())) { + Ok(resp) => { + let mut prefixed_data = vec![0; resp.len() + 4]; + prefixed_data[..4].copy_from_slice(&(resp.len() as u32).to_le_bytes()); + prefixed_data[4..].copy_from_slice(&resp); + if let Err(e) = stream.write_all(&prefixed_data).await { + tracing::error!(err=%e, "toyconsul: fail to write rpc response"); + } else { + tracing::info!(data=?prefixed_data, "toyconsul: send get response"); + } + } + Err(e) => { + tracing::error!(err=%e, "toyconsul: fail to encode rpc response"); + } + } + Ok(()) + } + + async fn handle_join( + &self, + id: NodeId, + addr: SocketAddr, + stream: &mut W, + ) -> Result<()> { + let (tx, rx) = oneshot::channel(); + self.inner.tx.send(Event::Join { + id: id.clone(), + addr, + tx, + })?; + + let resp = 
rx.await?; + if let Err(e) = resp { + let res = std::result::Result::<(), String>::Err(e.to_string()); + match bincode::serialize(&res) { + Ok(resp) => { + let mut prefixed_data = vec![0; resp.len() + 4]; + prefixed_data[..4].copy_from_slice(&(resp.len() as u32).to_le_bytes()); + prefixed_data[4..].copy_from_slice(&resp); + if let Err(e) = stream.write_all(&prefixed_data).await { + tracing::error!(err=%e, "toyconsul: fail to write rpc response"); + } + } + Err(e) => { + tracing::error!(err=%e, "toyconsul: fail to encode rpc response"); + } + } + } else { + let res = std::result::Result::<(), String>::Ok(()); + match bincode::serialize(&res) { + Ok(resp) => { + let mut prefixed_data = vec![0; resp.len() + 4]; + prefixed_data[..4].copy_from_slice(&(resp.len() as u32).to_le_bytes()); + prefixed_data[4..].copy_from_slice(&resp); + if let Err(e) = stream.write_all(&prefixed_data).await { + tracing::error!(err=%e, "toyconsul: fail to write rpc response"); + } + } + Err(e) => { + tracing::error!(err=%e, "toyconsul: fail to encode rpc response"); + } + } + } + + Ok(()) + } + + async fn handle_list(&self, stream: &mut W) -> Result<()> { + let (tx, rx) = oneshot::channel(); + self.inner.tx.send(Event::List { tx })?; + + let resp = rx.await?; + + match resp { + Ok(s) => { + let res = std::result::Result::, String>::Ok(s); + match bincode::serialize(&res) { + Ok(resp) => { + let mut prefixed_data = vec![0; resp.len() + 4]; + prefixed_data[..4].copy_from_slice(&(resp.len() as u32).to_le_bytes()); + prefixed_data[4..].copy_from_slice(&resp); + if let Err(e) = stream.write_all(&prefixed_data).await { + tracing::error!(err=%e, "toyconsul: fail to write rpc response"); + } + } + Err(e) => { + tracing::error!(err=%e, "toyconsul: fail to encode rpc response"); + } + } + } + Err(e) => { + let res = std::result::Result::<(), String>::Err(e.to_string()); + match bincode::serialize(&res) { + Ok(resp) => { + let mut prefixed_data = vec![0; resp.len() + 4]; + 
prefixed_data[..4].copy_from_slice(&(resp.len() as u32).to_le_bytes()); + prefixed_data[4..].copy_from_slice(&resp); + if let Err(e) = stream.write_all(&prefixed_data).await { + tracing::error!(err=%e, "toyconsul: fail to write rpc response"); + } + } + Err(e) => { + tracing::error!(err=%e, "toyconsul: fail to encode rpc response"); + } + } + } + } + Ok(()) + } +} + +#[derive(clap::Args)] +struct StartArgs { + /// The id of the db instance + #[clap(short, long)] + id: NodeId, + /// The address the memberlist should bind to + #[clap(short, long)] + addr: SocketAddr, + /// The rpc address to listen on commands + #[clap(short, long)] + rpc_addr: std::path::PathBuf, +} + +#[derive(clap::Subcommand)] +enum Commands { + /// Start the toyconsul instance + Start(StartArgs), + /// Join to an existing toyconsul cluster + Join { + #[clap(short, long)] + id: NodeId, + #[clap(short, long)] + addr: SocketAddr, + #[clap(short, long)] + rpc_addr: std::path::PathBuf, + }, + /// Register a service to the toyconsul + Register { + #[clap(short, long)] + name: String, + #[clap(short, long)] + addr: SocketAddr, + #[clap(short, long)] + rpc_addr: std::path::PathBuf, + }, + /// List all services in the toyconsul + List, +} + +#[derive(clap::Parser)] +#[command(name = "toyconsul")] +#[command(about = "CLI for toyconsul example", long_about = None)] +struct Cli { + #[clap(subcommand)] + command: Commands, +} + +#[derive(serde::Serialize, serde::Deserialize)] +enum Op { + Register { name: SmolStr, addr: SocketAddr }, + List, + Join { addr: SocketAddr, id: NodeId }, +} + +enum Event { + Register { + name: SmolStr, + addr: SocketAddr, + tx: oneshot::Sender>, + }, + List { + tx: oneshot::Sender>>, + }, + Join { + addr: SocketAddr, + id: NodeId, + tx: oneshot::Sender>, + }, +} + +#[tokio::main] +async fn main() -> Result<()> { + let filter = std::env::var("TOY_CONSUL_LOG").unwrap_or_else(|_| "info".to_owned()); + tracing::subscriber::set_global_default( + tracing_subscriber::fmt::fmt() + 
.without_time() + .with_line_number(true) + .with_env_filter(filter) + .with_file(false) + .with_target(true) + .with_ansi(true) + .finish(), + ) + .unwrap(); + + let cli = Cli::parse(); + match cli.command { + Commands::Join { addr, id, rpc_addr } => { + handle_join_cmd(id, addr, rpc_addr).await?; + } + Commands::Register { + name, + addr, + rpc_addr, + } => { + handle_register_cmd(name, addr, rpc_addr).await?; + } + Commands::Start(args) => { + handle_start_cmd(args).await?; + } + Commands::List => { + handle_list_cmd().await?; + } + } + + Ok(()) +} + +async fn handle_join_cmd(id: NodeId, addr: SocketAddr, rpc_addr: std::path::PathBuf) -> Result<()> { + let conn = UnixStream::connect(rpc_addr).await?; + let data = serialize(&Op::Join { id, addr })?; + + let (reader, mut writer) = conn.into_split(); + + let mut prefixed_data = vec![0; data.len() + 4]; + prefixed_data[..4].copy_from_slice(&(data.len() as u32).to_le_bytes()); + prefixed_data[4..].copy_from_slice(&data); + + writer.write_all(&prefixed_data).await?; + writer.shutdown().await?; + + let mut reader = tokio::io::BufReader::new(reader); + let mut len_buf = [0; 4]; + reader.read_exact(&mut len_buf).await?; + let len = u32::from_le_bytes(len_buf) as usize; + + let mut buf = vec![0; len]; + reader.read_exact(&mut buf).await?; + let res = deserialize::>(&buf)?; + match res { + Ok(_) => { + println!("join successfully"); + } + Err(e) => { + println!("fail to join {e}") + } + } + Ok(()) +} + +async fn handle_register_cmd( + name: String, + addr: SocketAddr, + rpc_addr: std::path::PathBuf, +) -> Result<()> { + let conn = UnixStream::connect(rpc_addr).await?; + let data = serialize(&Op::Register { + name: name.clone().into(), + addr, + })?; + + let (reader, mut writer) = conn.into_split(); + + let mut prefixed_data = vec![0; data.len() + 4]; + prefixed_data[..4].copy_from_slice(&(data.len() as u32).to_le_bytes()); + prefixed_data[4..].copy_from_slice(&data); + + writer.write_all(&prefixed_data).await?; + 
writer.shutdown().await?; + + let mut reader = tokio::io::BufReader::new(reader); + let mut len_buf = [0; 4]; + reader.read_exact(&mut len_buf).await?; + let len = u32::from_le_bytes(len_buf) as usize; + + let mut buf = vec![0; len]; + reader.read_exact(&mut buf).await?; + let res = deserialize::>(&buf)?; + match res { + Ok(_) => { + println!("register {}({}) successfully", name, addr); + } + Err(e) => { + println!("fail to register {e}"); + } + } + Ok(()) +} + +async fn handle_start_cmd(args: StartArgs) -> Result<()> { + let opts = Options::new().with_memberlist_options(MemberlistOptions::local()); + let net_opts = + NetTransportOptions::new(args.id).with_bind_addresses([args.addr].into_iter().collect()); + + let consul = ToyConsul::new(opts, net_opts).await?; + + struct Guard { + sock: std::path::PathBuf, + } + + impl Drop for Guard { + fn drop(&mut self) { + if let Err(e) = std::fs::remove_file(&self.sock) { + tracing::error!(err=%e, "toyconsul: fail to remove rpc sock"); + } + } + } + + let _guard = Guard { + sock: args.rpc_addr.clone(), + }; + + let listener = UnixListener::bind(&args.rpc_addr)?; + + tracing::info!("toyconsul: start listening on {}", args.rpc_addr.display()); + + loop { + tokio::select! 
{ + conn = listener.accept() => { + let (stream, _) = conn?; + let mut stream = tokio::io::BufReader::new(stream); + let mut len_buf = [0; 4]; + stream.read_exact(&mut len_buf).await?; + let len = u32::from_le_bytes(len_buf) as usize; + + let mut data = vec![0; len]; + if let Err(e) = stream.read_exact(&mut data).await { + tracing::error!(err=%e, "toyconsul: fail to read from rpc stream"); + continue; + } + + let op: Op = match bincode::deserialize(&data) { + Ok(op) => op, + Err(e) => { + tracing::error!(err=%e, "toyconsul: fail to decode rpc message"); + continue; + } + }; + + match op { + Op::Join { addr, id } => { + consul.handle_join(id, addr, &mut stream).await?; + } + Op::Register { + name, + addr, + } => { + consul.handle_register(name, addr, &mut stream).await?; + }, + Op::List => { + consul.handle_list(&mut stream).await?; + } + } + + if let Err(e) = stream.into_inner().shutdown().await { + tracing::error!(err=%e, "toyconsul: fail to shutdown rpc stream"); + } + } + _ = tokio::signal::ctrl_c() => { + break; + } + } + } + Ok(()) +} + +async fn handle_list_cmd() -> Result<()> { + let conn = UnixStream::connect("/tmp/toyconsul.sock").await?; + let data = serialize(&Op::List)?; + + let (reader, mut writer) = conn.into_split(); + + let mut prefixed_data = vec![0; data.len() + 4]; + prefixed_data[..4].copy_from_slice(&(data.len() as u32).to_le_bytes()); + prefixed_data[4..].copy_from_slice(&data); + + writer.write_all(&prefixed_data).await?; + writer.shutdown().await?; + + let mut reader = tokio::io::BufReader::new(reader); + let mut len_buf = [0; 4]; + reader.read_exact(&mut len_buf).await?; + let len = u32::from_le_bytes(len_buf) as usize; + + let mut buf = vec![0; len]; + reader.read_exact(&mut buf).await?; + let res = deserialize::, String>>(&buf)?; + match res { + Ok(services) => { + for service in services { + println!("{}({})", service.name, service.addr); + } + } + Err(e) => { + println!("fail to list {e}") + } + } + Ok(()) +} diff --git 
a/serf-core/src/types.rs b/serf-core/src/types.rs index d79084d..fd5686d 100644 --- a/serf-core/src/types.rs +++ b/serf-core/src/types.rs @@ -1,9 +1,11 @@ use std::time::Duration; pub use memberlist_core::proto::{ - DelegateVersion as MemberlistDelegateVersion, Domain, HostAddr, Node, NodeId, ParseDomainError, - ParseHostAddrError, ParseNodeIdError, ProtocolVersion as MemberlistProtocolVersion, + DelegateVersion as MemberlistDelegateVersion, Domain, HostAddr, MaybeResolvedAddress, Node, + NodeId, ParseDomainError, ParseHostAddrError, ParseNodeIdError, + ProtocolVersion as MemberlistProtocolVersion, bytes, }; +pub use smol_str::*; #[cfg(feature = "encryption")] #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] From 3cbb8b84550f2eb6d46412d7ab419cd64f890c15 Mon Sep 17 00:00:00 2001 From: al8n Date: Tue, 4 Mar 2025 17:04:44 +0800 Subject: [PATCH 30/39] Update --- README.md | 11 +++++++---- examples/toy-consul/src/main.rs | 12 ++++-------- serf/Cargo.toml | 2 +- serf/src/async_std.rs | 32 ++++---------------------------- serf/src/smol.rs | 32 ++++---------------------------- serf/src/tokio.rs | 32 ++++---------------------------- 6 files changed, 24 insertions(+), 97 deletions(-) diff --git a/README.md b/README.md index 2e7958c..66f2e39 100644 --- a/README.md +++ b/README.md @@ -17,11 +17,12 @@ Port and improve [HashiCorp's serf](https://github.com/hashicorp/serf) to Rust. [Build][CI-url] [codecov][codecov-url] -[docs.rs][doc-url] +[docs.rs][doc-url] [crates.io][crates-url] [crates.io][crates-url] license +[github][discord] English | [简体中文][zh-cn-url] @@ -162,10 +163,11 @@ Here are the layers: ## Related Projects -- [`agnostic`](https://github.com/al8n/agnostic): helps you to develop runtime agnostic crates +- [`agnostic`](https://github.com/al8n/agnostic): Helps you to develop runtime agnostic crates +- [`agnostic-mdns`](https://github.com/al8n/agnostic-mdns): Simple and lightweight mDNS client/server library for any async runtime. 
- [`getifs`](https://github.com/al8n/getifs): A bunch of cross platform network tools for fetching interfaces, multicast addresses, local ip addresses, private ip addresses, public ip addresses and etc. -- [`nodecraft`](https://github.com/al8n/nodecraft): crafting seamless node operations for distributed systems, which provides foundational traits for node identification and address resolution. -- [`peekable`](https://github.com/al8n/peekable): peekable reader and async reader +- [`nodecraft`](https://github.com/al8n/nodecraft): Crafting seamless node operations for distributed systems, which provides foundational traits for node identification and address resolution. +- [`peekable`](https://github.com/al8n/peekable): Peekable reader and async reader. - [`memberlist`](https://github.com/al8n/memberlist): A highly customable, adaptable, runtime agnostic and WASM/WASI friendly Gossip protocol which helps manage cluster membership and member failure detection. #### License @@ -184,3 +186,4 @@ Copyright (c) 2013 HashiCorp, Inc. 
[crates-url]: https://crates.io/crates/serf [codecov-url]: https://app.codecov.io/gh/al8n/serf/ [zh-cn-url]: https://github.com/al8n/serf/tree/main/README-zh_CN.md +[discord]: https://discord.gg/4JyVhKFcrt diff --git a/examples/toy-consul/src/main.rs b/examples/toy-consul/src/main.rs index 1774fa0..2ef342a 100644 --- a/examples/toy-consul/src/main.rs +++ b/examples/toy-consul/src/main.rs @@ -5,13 +5,9 @@ use clap::Parser; use crossbeam_skiplist::SkipMap; use serf::{ MemberlistOptions, Options, - agnostic::tokio::TokioRuntime, delegate::CompositeDelegate, - net::{ - NetTransportOptions, Node, NodeId, resolver::socket_addr::SocketAddrResolver, - stream_layer::tcp::Tcp, - }, - tokio::TokioTcpSerf, + net::{NetTransportOptions, Node, NodeId}, + tokio::{TokioSocketAddrResolver, TokioTcp, TokioTcpSerf}, types::{MaybeResolvedAddress, SmolStr}, }; @@ -35,7 +31,7 @@ struct Service { } struct Inner { - serf: TokioTcpSerf, ConsulDelegate>, + serf: TokioTcpSerf, services: SkipMap, tx: UnboundedSender, } @@ -48,7 +44,7 @@ struct ToyConsul { impl ToyConsul { async fn new( opts: Options, - net_opts: NetTransportOptions, Tcp>, + net_opts: NetTransportOptions, ) -> Result { let serf = TokioTcpSerf::new(net_opts, opts.with_event_buffer_size(256)).await?; let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); diff --git a/serf/Cargo.toml b/serf/Cargo.toml index 57e9ff7..b426606 100644 --- a/serf/Cargo.toml +++ b/serf/Cargo.toml @@ -40,7 +40,7 @@ quic = ["memberlist/quic"] quinn = ["memberlist/quinn", "quic"] net = ["memberlist/net"] -tcp = ["net"] +tcp = ["memberlist/tcp", "net"] tls = ["memberlist/tls", "net"] # enable DNS node address resolver diff --git a/serf/src/async_std.rs b/serf/src/async_std.rs index a030d4b..64a9021 100644 --- a/serf/src/async_std.rs +++ b/serf/src/async_std.rs @@ -1,4 +1,4 @@ -pub use memberlist::agnostic::async_std::AsyncStdRuntime; +pub use memberlist::async_std::*; /// [`Serf`](super::Serf) type alias for using 
[`NetTransport`](memberlist::net::NetTransport) and [`Tcp`](memberlist::net::stream_layer::tcp::Tcp) stream layer with `async-std` runtime. #[cfg(all(any(feature = "tcp", feature = "tls",), not(target_family = "wasm")))] @@ -6,38 +6,14 @@ pub use memberlist::agnostic::async_std::AsyncStdRuntime; docsrs, doc(cfg(all(any(feature = "tcp", feature = "tls",), not(target_family = "wasm")))) )] -pub type AsyncStdTcpSerf = serf_core::Serf< - memberlist::net::NetTransport< - I, - A, - memberlist::net::stream_layer::tcp::Tcp, - AsyncStdRuntime, - >, - D, ->; +pub type AsyncStdTcpSerf = serf_core::Serf, D>; /// [`Serf`](super::Serf) type alias for using [`NetTransport`](memberlist::net::NetTransport) and [`Tls`](memberlist::net::stream_layer::tls::Tls) stream layer with `async-std` runtime. #[cfg(all(feature = "tls", not(target_family = "wasm")))] #[cfg_attr(docsrs, doc(cfg(all(feature = "tls", not(target_family = "wasm")))))] -pub type AsyncStdTlsSerf = serf_core::Serf< - memberlist::net::NetTransport< - I, - A, - memberlist::net::stream_layer::tls::Tls, - AsyncStdRuntime, - >, - D, ->; +pub type AsyncStdTlsSerf = serf_core::Serf, D>; /// [`Serf`](super::Serf) type alias for using [`QuicTransport`](memberlist::quic::QuicTransport) and [`Quinn`](memberlist::quic::stream_layer::quinn::Quinn) stream layer with `async-std` runtime. 
#[cfg(all(feature = "quinn", not(target_family = "wasm")))] #[cfg_attr(docsrs, doc(cfg(all(feature = "quinn", not(target_family = "wasm")))))] -pub type AsyncStdQuicSerf = serf_core::Serf< - memberlist::quic::QuicTransport< - I, - A, - memberlist::quic::stream_layer::quinn::Quinn, - AsyncStdRuntime, - >, - D, ->; +pub type AsyncStdQuicSerf = serf_core::Serf, D>; diff --git a/serf/src/smol.rs b/serf/src/smol.rs index 0c56d9b..97c1fc0 100644 --- a/serf/src/smol.rs +++ b/serf/src/smol.rs @@ -1,4 +1,4 @@ -pub use memberlist::agnostic::smol::SmolRuntime; +pub use memberlist::smol::*; /// [`Serf`](super::Serf) type alias for using [`NetTransport`](memberlist::net::NetTransport) and [`Tcp`](memberlist::net::stream_layer::tcp::Tcp) stream layer with `smol` runtime. #[cfg(all(any(feature = "tcp", feature = "tls",), not(target_family = "wasm")))] @@ -6,38 +6,14 @@ pub use memberlist::agnostic::smol::SmolRuntime; docsrs, doc(cfg(all(any(feature = "tcp", feature = "tls",), not(target_family = "wasm")))) )] -pub type SmolTcpSerf = serf_core::Serf< - memberlist::net::NetTransport< - I, - A, - memberlist::net::stream_layer::tcp::Tcp, - SmolRuntime, - >, - D, ->; +pub type SmolTcpSerf = serf_core::Serf, D>; /// [`Serf`](super::Serf) type alias for using [`NetTransport`](memberlist::net::NetTransport) and [`Tls`](memberlist::net::stream_layer::tls::Tls) stream layer with `smol` runtime. #[cfg(all(feature = "tls", not(target_family = "wasm")))] #[cfg_attr(docsrs, doc(cfg(all(feature = "tls", not(target_family = "wasm")))))] -pub type SmolTlsSerf = serf_core::Serf< - memberlist::net::NetTransport< - I, - A, - memberlist::net::stream_layer::tls::Tls, - SmolRuntime, - >, - D, ->; +pub type SmolTlsSerf = serf_core::Serf, D>; /// [`Serf`](super::Serf) type alias for using [`QuicTransport`](memberlist::quic::QuicTransport) and [`Quinn`](memberlist::quic::stream_layer::quinn::Quinn) stream layer with `smol` runtime. 
#[cfg(all(feature = "quinn", not(target_family = "wasm")))] #[cfg_attr(docsrs, doc(cfg(all(feature = "quinn", not(target_family = "wasm")))))] -pub type SmolQuicSerf = serf_core::Serf< - memberlist::quic::QuicTransport< - I, - A, - memberlist::quic::stream_layer::quinn::Quinn, - SmolRuntime, - >, - D, ->; +pub type SmolQuicSerf = serf_core::Serf, D>; diff --git a/serf/src/tokio.rs b/serf/src/tokio.rs index 2e4f8e6..c43f0e2 100644 --- a/serf/src/tokio.rs +++ b/serf/src/tokio.rs @@ -1,4 +1,4 @@ -pub use memberlist::agnostic::tokio::TokioRuntime; +pub use memberlist::tokio::*; /// [`Serf`](super::Serf) type alias for using [`NetTransport`](memberlist::net::NetTransport) and [`Tcp`](memberlist::net::stream_layer::tcp::Tcp) stream layer with `tokio` runtime. #[cfg(all(any(feature = "tcp", feature = "tls",), not(target_family = "wasm")))] @@ -6,38 +6,14 @@ pub use memberlist::agnostic::tokio::TokioRuntime; docsrs, doc(cfg(all(any(feature = "tcp", feature = "tls",), not(target_family = "wasm")))) )] -pub type TokioTcpSerf = serf_core::Serf< - memberlist::net::NetTransport< - I, - A, - memberlist::net::stream_layer::tcp::Tcp, - TokioRuntime, - >, - D, ->; +pub type TokioTcpSerf = serf_core::Serf, D>; /// [`Serf`](super::Serf) type alias for using [`NetTransport`](memberlist::net::NetTransport) and [`Tls`](memberlist::net::stream_layer::tls::Tls) stream layer with `tokio` runtime. #[cfg(all(feature = "tls", not(target_family = "wasm")))] #[cfg_attr(docsrs, doc(cfg(all(feature = "tls", not(target_family = "wasm")))))] -pub type TokioTlsSerf = serf_core::Serf< - memberlist::net::NetTransport< - I, - A, - memberlist::net::stream_layer::tls::Tls, - TokioRuntime, - >, - D, ->; +pub type TokioTlsSerf = serf_core::Serf, D>; /// [`Serf`](super::Serf) type alias for using [`QuicTransport`](memberlist::quic::QuicTransport) and [`Quinn`](memberlist::quic::stream_layer::quinn::Quinn) stream layer with `tokio` runtime. 
#[cfg(all(feature = "quinn", not(target_family = "wasm")))] #[cfg_attr(docsrs, doc(cfg(all(feature = "quinn", not(target_family = "wasm")))))] -pub type TokioQuicSerf = serf_core::Serf< - memberlist::quic::QuicTransport< - I, - A, - memberlist::quic::stream_layer::quinn::Quinn, - TokioRuntime, - >, - D, ->; +pub type TokioQuicSerf = serf_core::Serf, D>; From de65eed5b7add51708e20c5a050f3b4c517246d8 Mon Sep 17 00:00:00 2001 From: al8n Date: Tue, 4 Mar 2025 22:05:49 +0800 Subject: [PATCH 31/39] Finish example --- examples/toy-consul/Cargo.toml | 2 +- examples/toy-consul/README.md | 6 +- examples/toy-consul/src/main.rs | 115 ++++++++++++++++++++++++-------- serf-core/src/serf/api.rs | 3 +- serf/src/lib.rs | 6 +- 5 files changed, 95 insertions(+), 37 deletions(-) diff --git a/examples/toy-consul/Cargo.toml b/examples/toy-consul/Cargo.toml index a12665a..096b9b8 100644 --- a/examples/toy-consul/Cargo.toml +++ b/examples/toy-consul/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "toy-consul" +name = "toyconsul" rust-version = "1.85.0" edition = "2024" publish = false diff --git a/examples/toy-consul/README.md b/examples/toy-consul/README.md index b62a222..0ad2f35 100644 --- a/examples/toy-consul/README.md +++ b/examples/toy-consul/README.md @@ -13,7 +13,7 @@ cargo install --path . - In the first terminal ```bash - toyconsul start --id instance1 --addr 127.0.0.1:7001 --meta instance1 --rpc-addr toyconsul.instance1.sock + toyconsul start --id instance1 --addr 127.0.0.1:7001 --rpc-addr toyconsul.instance1.sock ``` - In the second terminal @@ -21,7 +21,7 @@ cargo install --path . - Start instance 2 ```bash - toyconsul start --id instance2 --addr 127.0.0.1:7002 --meta instance2 --rpc-addr toyconsul.instance2.sock + toyconsul start --id instance2 --addr 127.0.0.1:7002 --rpc-addr toyconsul.instance2.sock ``` - Send the join command to instance2 and let it join to instance1 @@ -35,7 +35,7 @@ cargo install --path . 
- Start instance 3 ```bash - toyconsul start --id instance3 --addr 127.0.0.1:7003 --meta instance3 --rpc-addr toyconsul.instance3.sock + toyconsul start --id instance3 --addr 127.0.0.1:7003 --rpc-addr toyconsul.instance3.sock ``` - Send the join command to instance3 and let it join to instance1 (can also join to instance 2) diff --git a/examples/toy-consul/src/main.rs b/examples/toy-consul/src/main.rs index 2ef342a..5b0011a 100644 --- a/examples/toy-consul/src/main.rs +++ b/examples/toy-consul/src/main.rs @@ -1,4 +1,4 @@ -use std::{net::SocketAddr, sync::Arc}; +use std::{net::SocketAddr, path::PathBuf, sync::Arc}; use bincode::{deserialize, serialize}; use clap::Parser; @@ -6,7 +6,8 @@ use crossbeam_skiplist::SkipMap; use serf::{ MemberlistOptions, Options, delegate::CompositeDelegate, - net::{NetTransportOptions, Node, NodeId}, + event::{Event as SerfEvent, EventProducer, EventSubscriber}, + net::{NetTransportOptions, Node, NodeId, TokioNetTransport}, tokio::{TokioSocketAddrResolver, TokioTcp, TokioTcpSerf}, types::{MaybeResolvedAddress, SmolStr}, }; @@ -46,8 +47,11 @@ impl ToyConsul { opts: Options, net_opts: NetTransportOptions, ) -> Result { - let serf = TokioTcpSerf::new(net_opts, opts.with_event_buffer_size(256)).await?; let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let (producer, subscriber) = EventProducer::unbounded(); + let serf = + TokioTcpSerf::with_event_producer(net_opts, opts.with_event_buffer_size(256), producer) + .await?; let this = Self { inner: Inner { @@ -58,37 +62,89 @@ impl ToyConsul { .into(), }; - this.clone().handle_events(rx); + this.clone().handle_serf_events(subscriber); + this.clone().handle_rpc_events(rx); Ok(this) } - fn handle_events(self, mut rx: UnboundedReceiver) { + fn handle_rpc_events(self, mut rx: UnboundedReceiver) { tokio::spawn(async move { loop { tokio::select! 
{ _ = tokio::signal::ctrl_c() => { - tracing::info!("toyconsul: shutting down event listener"); + tracing::info!("toyconsul: shutting down rpc listener"); } ev = rx.recv() => { - if let Some(ev) = ev { - match ev { - Event::Join { id, addr, tx } => { - let res = self.inner.serf.join(Node::new(id, MaybeResolvedAddress::Resolved(addr)), false).await; - let _ = tx.send(res.map_err(Into::into).map(|_| ())); - } - Event::Register { name, addr, tx } => { - self.inner.services.insert( - name.clone(), - Service { - name, - addr, - }); - let _ = tx.send(Ok(())); + match ev { + Some(Event::Register { name, addr, tx }) => { + let service = Service { name, addr }; + match bincode::serialize(&service) { + Ok(data) => { + // broadcast a register event to all members + match self.inner.serf.user_event("register", data, false).await { + Ok(_) => { + let _ = tx.send(Ok(())); + } + Err(e) => { + tracing::error!(err=%e, "toyconsul: fail to send register event"); + let _ = tx.send(Err(e.into())); + } + } + } + Err(e) => { + tracing::error!(err=%e, "toyconsul: fail to encode register response"); + let _ = tx.send(Err(e.into())); + } } - Event::List { tx } => { - let services = self.inner.services.iter().map(|ent| ent.value().clone()).collect(); - let _ = tx.send(Ok(services)); + } + Some(Event::List { tx }) => { + let services = self.inner.services.iter().map(|ent| ent.value().clone()).collect(); + let _ = tx.send(Ok(services)); + } + Some(Event::Join { id, addr, tx }) => { + let res = self.inner.serf.join(Node::new(id, MaybeResolvedAddress::Resolved(addr)), false).await; + let _ = tx.send(res.map_err(Into::into).map(|_| ())); + } + None => { + break; + } + } + } + } + } + }); + } + + fn handle_serf_events( + self, + subscriber: EventSubscriber< + TokioNetTransport, + ConsulDelegate, + >, + ) { + tokio::spawn(async move { + loop { + tokio::select! 
{ + _ = tokio::signal::ctrl_c() => { + tracing::info!("toyconsul: shutting down event listener"); + } + ev = subscriber.recv() => { + if let Ok(SerfEvent::User(ev)) = ev { + match ev.name().as_str() { + "register" => { + let payload = ev.payload(); + let service: Service = match bincode::deserialize(payload) { + Ok(service) => service, + Err(e) => { + tracing::error!(err=%e, "toyconsul: fail to decode register event"); + continue; + } + }; + self.inner.services.insert(service.name.clone(), service); + }, + other => { + tracing::warn!("toyconsul: unknown user event {}", other); } } } @@ -259,7 +315,10 @@ enum Commands { rpc_addr: std::path::PathBuf, }, /// List all services in the toyconsul - List, + List { + #[clap(short, long)] + rpc_addr: std::path::PathBuf, + }, } #[derive(clap::Parser)] @@ -323,8 +382,8 @@ async fn main() -> Result<()> { Commands::Start(args) => { handle_start_cmd(args).await?; } - Commands::List => { - handle_list_cmd().await?; + Commands::List { rpc_addr } => { + handle_list_cmd(rpc_addr).await?; } } @@ -479,8 +538,8 @@ async fn handle_start_cmd(args: StartArgs) -> Result<()> { Ok(()) } -async fn handle_list_cmd() -> Result<()> { - let conn = UnixStream::connect("/tmp/toyconsul.sock").await?; +async fn handle_list_cmd(rpc_addr: PathBuf) -> Result<()> { + let conn = UnixStream::connect(rpc_addr).await?; let data = serialize(&Op::List)?; let (reader, mut writer) = conn.into_split(); diff --git a/serf-core/src/serf/api.rs b/serf-core/src/serf/api.rs index 50487cd..966ec14 100644 --- a/serf-core/src/serf/api.rs +++ b/serf-core/src/serf/api.rs @@ -299,8 +299,7 @@ where } /// Used to broadcast a new query. The query must be fairly small, - /// and an error will be returned if the size limit is exceeded. This is only - /// available with protocol version 4 and newer. Query parameters are optional, + /// and an error will be returned if the size limit is exceeded. 
Query parameters are optional, /// and if not provided, a sane set of defaults will be used. pub async fn query( &self, diff --git a/serf/src/lib.rs b/serf/src/lib.rs index 8a1026e..68d37bb 100644 --- a/serf/src/lib.rs +++ b/serf/src/lib.rs @@ -16,17 +16,17 @@ pub use memberlist::net; #[cfg(feature = "quic")] pub use memberlist::quic; -/// [`Serf`](serf_core::Serf) for `tokio` runtime. +/// [`Serf`] for `tokio` runtime. #[cfg(feature = "tokio")] #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))] pub mod tokio; -/// [`Serf`](serf_core::Serf) for `async-std` runtime. +/// [`Serf`] for `async-std` runtime. #[cfg(feature = "async-std")] #[cfg_attr(docsrs, doc(cfg(feature = "async-std")))] pub mod async_std; -/// [`Serf`](serf_core::Serf) for `smol` runtime. +/// [`Serf`] for `smol` runtime. #[cfg(feature = "smol")] #[cfg_attr(docsrs, doc(cfg(feature = "smol")))] pub mod smol; From ae77b1d1808876ac07db0caf14753a7f9f544308 Mon Sep 17 00:00:00 2001 From: al8n Date: Tue, 4 Mar 2025 22:08:32 +0800 Subject: [PATCH 32/39] Add CHANGELOG.md --- CHANGELOG.md | 26 +++++++++++++++++++ Cargo.toml | 2 +- examples/{toy-consul => toyconsul}/Cargo.toml | 0 examples/{toy-consul => toyconsul}/README.md | 2 +- .../{toy-consul => toyconsul}/src/main.rs | 0 5 files changed, 28 insertions(+), 2 deletions(-) create mode 100644 CHANGELOG.md rename examples/{toy-consul => toyconsul}/Cargo.toml (100%) rename examples/{toy-consul => toyconsul}/README.md (96%) rename examples/{toy-consul => toyconsul}/src/main.rs (100%) diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..5bc891f --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,26 @@ +# Releases + +## 0.3.0 + +### Features + +- Redesign `Delegate` trait, making it easier to implement for users. +- Rewriting encoding/decoding to support forward and backward compitibility. +- Support `zstd`, `brotli`, `lz4`, and `snappy` for compressing. +- Support `crc32`, `xxhash64`, `xxhash32`, `xxhash3`, `murmur3` for checksuming. 
+- Unify returned error, all exported APIs return `Error` on `Result::Err`. + +### Example + +- Add [`toyconsul`](./examples/toyconsul/) Example + +### Breakage + +- Remove `native-tls` support +- Remove `s2n-quic` support +- Remove `TransformDelegate` trait to simplify `Delegate` trait +- Remove `JoinError`, add a new `Error::Multiple` variant + +### Testing + +- Add fuzz testing for encoding/decoding diff --git a/Cargo.toml b/Cargo.toml index 0e3c6ee..dbbaf11 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ members = [ "serf", "serf-core", - "examples/toy-consul", + "examples/toyconsul", ] resolver = "3" diff --git a/examples/toy-consul/Cargo.toml b/examples/toyconsul/Cargo.toml similarity index 100% rename from examples/toy-consul/Cargo.toml rename to examples/toyconsul/Cargo.toml diff --git a/examples/toy-consul/README.md b/examples/toyconsul/README.md similarity index 96% rename from examples/toy-consul/README.md rename to examples/toyconsul/README.md index 0ad2f35..ce0c4b9 100644 --- a/examples/toy-consul/README.md +++ b/examples/toyconsul/README.md @@ -1,6 +1,6 @@ # ToyConsul -A toy eventually consensus distributed database. +A toy eventually consistent distributed registry. 
## Installation diff --git a/examples/toy-consul/src/main.rs b/examples/toyconsul/src/main.rs similarity index 100% rename from examples/toy-consul/src/main.rs rename to examples/toyconsul/src/main.rs From dc8bfde378fd22c62ec52c0c1c0035252965a5f9 Mon Sep 17 00:00:00 2001 From: al8n Date: Tue, 4 Mar 2025 22:47:51 +0800 Subject: [PATCH 33/39] Add fuzzy CI --- .github/workflows/ci.yml | 2 +- .github/workflows/fuzz.yml | 41 +++++++ .github/workflows/net.yml | 2 +- Cargo.toml | 1 + fuzz/.gitignore | 4 + fuzz/Cargo.toml | 22 ++++ fuzz/fuzz_targets/messages.rs | 15 +++ serf-core/src/types.rs | 11 +- serf-core/src/types/fuzzy.rs | 202 ++++++++++++++++++++++++++++++++++ serf-core/src/types/tests.rs | 172 +---------------------------- 10 files changed, 300 insertions(+), 172 deletions(-) create mode 100644 .github/workflows/fuzz.yml create mode 100644 fuzz/.gitignore create mode 100644 fuzz/Cargo.toml create mode 100644 fuzz/fuzz_targets/messages.rs create mode 100644 serf-core/src/types/fuzzy.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d36f4a2..35ab2c8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,7 @@ on: - '**.txt' - 'art' workflow_dispatch: - schedule: [cron: "40 1 * * *"] + schedule: [cron: "0 1 */7 * *"] jobs: test: diff --git a/.github/workflows/fuzz.yml b/.github/workflows/fuzz.yml new file mode 100644 index 0000000..71f5175 --- /dev/null +++ b/.github/workflows/fuzz.yml @@ -0,0 +1,41 @@ +name: Fuzz Testing +on: + push: + branches: + - main + paths-ignore: + - 'README.md' + - 'COPYRIGHT' + - 'LICENSE*' + - '**.md' + - '**.txt' + - 'art' + pull_request: + paths-ignore: + - 'README.md' + - 'COPYRIGHT' + - 'LICENSE*' + - '**.md' + - '**.txt' + - 'art' + schedule: [cron: "0 1 */7 * *"] + + +jobs: + fuzz: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly + + - name: Install cargo-fuzz + run: cargo 
install cargo-fuzz + + - name: Run fuzzing + run: | + cargo fuzz build + cargo fuzz run messages -- -max_total_time=300 diff --git a/.github/workflows/net.yml b/.github/workflows/net.yml index c3ade14..c2ff157 100644 --- a/.github/workflows/net.yml +++ b/.github/workflows/net.yml @@ -20,7 +20,7 @@ on: - '**.txt' - 'art' workflow_dispatch: - schedule: [cron: "40 1 * * *"] + schedule: [cron: "0 1 */7 * *"] jobs: test: diff --git a/Cargo.toml b/Cargo.toml index dbbaf11..c844280 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ members = [ "serf", "serf-core", "examples/toyconsul", + "fuzz", ] resolver = "3" diff --git a/fuzz/.gitignore b/fuzz/.gitignore new file mode 100644 index 0000000..1a45eee --- /dev/null +++ b/fuzz/.gitignore @@ -0,0 +1,4 @@ +target +corpus +artifacts +coverage diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml new file mode 100644 index 0000000..29c50df --- /dev/null +++ b/fuzz/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "serf-fuzz" +version = "0.0.0" +publish = false +edition = "2024" + +[package.metadata] +cargo-fuzz = true + +[dependencies] +libfuzzer-sys = "0.4" + +[dependencies.serf-core] +path = "../serf-core" +features = ["arbitrary", "encryption"] + +[[bin]] +name = "messages" +path = "fuzz_targets/messages.rs" +test = false +doc = false +bench = false diff --git a/fuzz/fuzz_targets/messages.rs b/fuzz/fuzz_targets/messages.rs new file mode 100644 index 0000000..c5e4dc4 --- /dev/null +++ b/fuzz/fuzz_targets/messages.rs @@ -0,0 +1,15 @@ +#![no_main] +#![allow(clippy::type_complexity)] + +use libfuzzer_sys::fuzz_target; + +use serf_core::types::{ + Node, + fuzzy::{Message, encodable_round_trip}, +}; + +fuzz_target!( + |data: (Message>, Option>>)| { + assert!(encodable_round_trip(data.0, data.1)); + } +); diff --git a/serf-core/src/types.rs b/serf-core/src/types.rs index fd5686d..8ad419b 100644 --- a/serf-core/src/types.rs +++ b/serf-core/src/types.rs @@ -1,9 +1,9 @@ use std::time::Duration; pub use memberlist_core::proto::{ - 
DelegateVersion as MemberlistDelegateVersion, Domain, HostAddr, MaybeResolvedAddress, Node, - NodeId, ParseDomainError, ParseHostAddrError, ParseNodeIdError, - ProtocolVersion as MemberlistProtocolVersion, bytes, + Data, DataRef, DelegateVersion as MemberlistDelegateVersion, Domain, HostAddr, + MaybeResolvedAddress, Node, NodeId, ParseDomainError, ParseHostAddrError, ParseNodeIdError, + ProtocolVersion as MemberlistProtocolVersion, bytes, utils, }; pub use smol_str::*; @@ -54,6 +54,7 @@ mod arbitrary_impl; mod quickcheck_impl; #[cfg(test)] +#[cfg(feature = "quickcheck")] mod tests; mod clock; @@ -101,6 +102,10 @@ mod key; #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] pub use key::*; +#[cfg(any(feature = "arbitrary", feature = "quickcheck"))] +#[doc(hidden)] +pub mod fuzzy; + #[cfg(debug_assertions)] #[inline] fn debug_assert_write_eq(actual: usize, expected: usize) { diff --git a/serf-core/src/types/fuzzy.rs b/serf-core/src/types/fuzzy.rs new file mode 100644 index 0000000..5c69cc4 --- /dev/null +++ b/serf-core/src/types/fuzzy.rs @@ -0,0 +1,202 @@ +use super::*; + +use core::hash::Hash; + +use crate::types::bytes::Bytes; +use memberlist_core::proto::Data; + +/// Message for fuzzy testing +#[derive(Clone, Debug)] +pub enum Message { + /// Leave message + Leave(LeaveMessage), + /// Join message + Join(JoinMessage), + /// PushPull message + PushPull(PushPullMessage), + /// UserEvent message + UserEvent(UserEventMessage), + /// Query message + Query(QueryMessage), + /// QueryResponse message + QueryResponse(QueryResponseMessage), + /// ConflictResponse message + ConflictResponse(ConflictResponseMessage), + #[cfg(feature = "encryption")] + /// KeyRequest message + KeyRequest(KeyRequestMessage), + #[cfg(feature = "encryption")] + /// KeyResponse message + KeyResponse(KeyResponseMessage), +} + +#[cfg(feature = "arbitrary")] +const _: () = { + use arbitrary::{Arbitrary, Unstructured}; + + impl<'a, I, A> Arbitrary<'a> for Message + where + I: Arbitrary<'a> + Hash + 
Eq, + A: Arbitrary<'a>, + { + fn arbitrary(g: &mut Unstructured<'a>) -> arbitrary::Result { + loop { + let variant = MessageType::arbitrary(g)?; + + return Ok(match variant { + MessageType::ConflictResponse => Message::ConflictResponse(Arbitrary::arbitrary(g)?), + MessageType::Join => Message::Join(Arbitrary::arbitrary(g)?), + MessageType::Leave => Message::Leave(Arbitrary::arbitrary(g)?), + MessageType::PushPull => Message::PushPull(Arbitrary::arbitrary(g)?), + MessageType::Query => Message::Query(Arbitrary::arbitrary(g)?), + MessageType::QueryResponse => Message::QueryResponse(Arbitrary::arbitrary(g)?), + MessageType::UserEvent => Message::UserEvent(Arbitrary::arbitrary(g)?), + #[cfg(feature = "encryption")] + MessageType::KeyRequest => Message::KeyRequest(Arbitrary::arbitrary(g)?), + #[cfg(feature = "encryption")] + MessageType::KeyResponse => Message::KeyResponse(Arbitrary::arbitrary(g)?), + _ => continue, + }); + } + } + } +}; + +#[cfg(feature = "quickcheck")] +const _: () = { + use quickcheck::{Arbitrary, Gen}; + + impl Arbitrary for Message + where + I: Arbitrary + Hash + Eq, + A: Arbitrary, + { + fn arbitrary(g: &mut Gen) -> Self { + loop { + let variant = MessageType::arbitrary(g); + + return match variant { + MessageType::ConflictResponse => Message::ConflictResponse(Arbitrary::arbitrary(g)), + MessageType::Join => Message::Join(Arbitrary::arbitrary(g)), + MessageType::Leave => Message::Leave(Arbitrary::arbitrary(g)), + MessageType::PushPull => Message::PushPull(Arbitrary::arbitrary(g)), + MessageType::Query => Message::Query(Arbitrary::arbitrary(g)), + MessageType::QueryResponse => Message::QueryResponse(Arbitrary::arbitrary(g)), + MessageType::UserEvent => Message::UserEvent(Arbitrary::arbitrary(g)), + #[cfg(feature = "encryption")] + MessageType::KeyRequest => Message::KeyRequest(Arbitrary::arbitrary(g)), + #[cfg(feature = "encryption")] + MessageType::KeyResponse => Message::KeyResponse(Arbitrary::arbitrary(g)), + _ => continue, + }; + } + } + } +}; + 
+fn encode(data: &T) -> Bytes { + encode_message_to_bytes(data).unwrap() +} + +fn encode_relay(data: &T, node: &Node) -> Bytes +where + I: Data, + A: Data, + T: Encodable, +{ + encode_relay_message_to_bytes(data, node).unwrap() +} + +pub fn encodable_round_trip(msg: Message, node: Option>) -> bool +where + I: Data + Eq + Hash, + A: Data + PartialEq, +{ + macro_rules! encode_variant { + (< $($g:ty), +$(,)? > $variant:ident ($ty:ty) <- $input:ident) => {{ + let data = encode(&$input); + assert_eq!(data.len(), encoded_message_len(&$input), "message: length mismatch"); + let decoded = super::decode_message :: < $($g),* > (&data).unwrap(); + let MessageRef::$variant(decoded) = decoded else { return false }; + + let owned = <$ty as Data>::from_ref(decoded).unwrap(); + assert_eq!($input, owned, "message: decoded mismatch"); + true + }}; + (@relay< $($g:ty), +$(,)? > $variant:ident ($ty:ty) <- ($input:ident, $node:ident)) => {{ + let data = encode_relay(&$input, &$node); + assert_eq!(data.len(), encoded_relay_message_len(&$input, &$node), "relay message: length mismatch"); + let decoded = super::decode_message :: < $($g),* > (&data).unwrap(); + let MessageRef::Relay(RelayMessageRef { node, payload, .. 
}) = decoded else { return false }; + assert_eq!( as Data>::from_ref(node).unwrap(), $node, "relay message: node mismatch"); + + let decoded = super::decode_message :: < $($g),* > (&payload).unwrap(); + let MessageRef::$variant(decoded) = decoded else { return false }; + + let owned = <$ty as Data>::from_ref(decoded).unwrap(); + assert_eq!($input, owned, "relay message: decoded mismatch"); + true + }}; + } + + match node { + Some(node) => match msg { + Message::Leave(leave_message) => { + encode_variant!(@relay Leave(LeaveMessage) <- (leave_message, node)) + } + Message::Join(join_message) => { + encode_variant!(@relay Join(JoinMessage) <- (join_message, node)) + } + Message::PushPull(push_pull_message) => { + encode_variant!(@relay PushPull(PushPullMessage) <- (push_pull_message, node)) + } + Message::UserEvent(user_event_message) => { + encode_variant!(@relay UserEvent(UserEventMessage) <- (user_event_message, node)) + } + Message::Query(query_message) => { + encode_variant!(@relay Query(QueryMessage) <- (query_message, node)) + } + Message::QueryResponse(query_response_message) => { + encode_variant!(@relay QueryResponse(QueryResponseMessage) <- (query_response_message, node)) + } + Message::ConflictResponse(conflict_response_message) => { + encode_variant!(@relay ConflictResponse(ConflictResponseMessage) <- (conflict_response_message, node)) + } + #[cfg(feature = "encryption")] + Message::KeyRequest(key_request_message) => { + encode_variant!(@relay KeyRequest(KeyRequestMessage) <- (key_request_message, node)) + } + #[cfg(feature = "encryption")] + Message::KeyResponse(key_response_message) => { + encode_variant!(@relay KeyResponse(KeyResponseMessage) <- (key_response_message, node)) + } + }, + None => match msg { + Message::Leave(msg) => encode_variant!( Leave(LeaveMessage) <- msg), + Message::Join(join_message) => encode_variant!( Join(JoinMessage) <- join_message), + Message::PushPull(push_pull_message) => { + encode_variant!( PushPull(PushPullMessage) <- 
push_pull_message) + } + Message::UserEvent(user_event_message) => { + encode_variant!( UserEvent(UserEventMessage) <- user_event_message) + } + Message::Query(query_message) => { + encode_variant!( Query(QueryMessage) <- query_message) + } + Message::QueryResponse(query_response_message) => { + encode_variant!( QueryResponse(QueryResponseMessage) <- query_response_message) + } + Message::ConflictResponse(conflict_response_message) => { + encode_variant!( ConflictResponse(ConflictResponseMessage) <- conflict_response_message) + } + #[cfg(feature = "encryption")] + Message::KeyRequest(key_request_message) => { + encode_variant!( KeyRequest(KeyRequestMessage) <- key_request_message) + } + #[cfg(feature = "encryption")] + Message::KeyResponse(key_response_message) => { + encode_variant!( KeyResponse(KeyResponseMessage) <- key_response_message) + } + }, + } +} diff --git a/serf-core/src/types/tests.rs b/serf-core/src/types/tests.rs index 25253fe..f630a9d 100644 --- a/serf-core/src/types/tests.rs +++ b/serf-core/src/types/tests.rs @@ -1,12 +1,10 @@ -use std::hash::Hash; +use memberlist_core::proto::{Data, DataRef}; -use memberlist_core::{ - bytes::Bytes, - proto::{Data, DataRef}, +use super::{ + coordinate::Coordinate, + fuzzy::{Message, encodable_round_trip}, + *, }; -use quickcheck::{Arbitrary, Gen}; - -use super::{coordinate::Coordinate, *}; fn data_round_trip(data: &T) { let mut buf = vec![0; data.encoded_len() + 2]; @@ -112,166 +110,6 @@ data_round_trip! 
{ KeyResponseMessage, } -#[derive(Clone, Debug)] -enum Message { - /// Leave message - Leave(LeaveMessage), - /// Join message - Join(JoinMessage), - /// PushPull message - PushPull(PushPullMessage), - /// UserEvent message - UserEvent(UserEventMessage), - /// Query message - Query(QueryMessage), - /// QueryResponse message - QueryResponse(QueryResponseMessage), - /// ConflictResponse message - ConflictResponse(ConflictResponseMessage), - #[cfg(feature = "encryption")] - /// KeyRequest message - KeyRequest(KeyRequestMessage), - #[cfg(feature = "encryption")] - /// KeyResponse message - KeyResponse(KeyResponseMessage), -} - -impl Arbitrary for Message -where - I: Arbitrary + Hash + Eq, - A: Arbitrary, -{ - fn arbitrary(g: &mut Gen) -> Self { - loop { - let variant = MessageType::arbitrary(g); - - return match variant { - MessageType::ConflictResponse => { - Message::ConflictResponse(ConflictResponseMessage::arbitrary(g)) - } - MessageType::Join => Message::Join(JoinMessage::arbitrary(g)), - MessageType::Leave => Message::Leave(LeaveMessage::arbitrary(g)), - MessageType::PushPull => Message::PushPull(PushPullMessage::arbitrary(g)), - MessageType::Query => Message::Query(QueryMessage::arbitrary(g)), - MessageType::QueryResponse => Message::QueryResponse(QueryResponseMessage::arbitrary(g)), - MessageType::UserEvent => Message::UserEvent(UserEventMessage::arbitrary(g)), - #[cfg(feature = "encryption")] - MessageType::KeyRequest => Message::KeyRequest(KeyRequestMessage::arbitrary(g)), - #[cfg(feature = "encryption")] - MessageType::KeyResponse => Message::KeyResponse(KeyResponseMessage::arbitrary(g)), - _ => continue, - }; - } - } -} - -fn encode(data: &T) -> Bytes { - encode_message_to_bytes(data).unwrap() -} - -fn encode_relay(data: &T, node: &Node) -> Bytes -where - I: Data, - A: Data, - T: Encodable, -{ - encode_relay_message_to_bytes(data, node).unwrap() -} - -fn encodable_round_trip(msg: Message, node: Option>) -> bool -where - I: Data + Eq + Hash, - A: Data + 
PartialEq, -{ - macro_rules! encode_variant { - (< $($g:ty), +$(,)? > $variant:ident ($ty:ty) <- $input:ident) => {{ - let data = encode(&$input); - assert_eq!(data.len(), encoded_message_len(&$input), "message: length mismatch"); - let decoded = super::decode_message :: < $($g),* > (&data).unwrap(); - let MessageRef::$variant(decoded) = decoded else { return false }; - - let owned = <$ty as Data>::from_ref(decoded).unwrap(); - assert_eq!($input, owned, "message: decoded mismatch"); - true - }}; - (@relay< $($g:ty), +$(,)? > $variant:ident ($ty:ty) <- ($input:ident, $node:ident)) => {{ - let data = encode_relay(&$input, &$node); - assert_eq!(data.len(), encoded_relay_message_len(&$input, &$node), "relay message: length mismatch"); - let decoded = super::decode_message :: < $($g),* > (&data).unwrap(); - let MessageRef::Relay(RelayMessageRef { node, payload, .. }) = decoded else { return false }; - assert_eq!( as Data>::from_ref(node).unwrap(), $node, "relay message: node mismatch"); - - let decoded = super::decode_message :: < $($g),* > (&payload).unwrap(); - let MessageRef::$variant(decoded) = decoded else { return false }; - - let owned = <$ty as Data>::from_ref(decoded).unwrap(); - assert_eq!($input, owned, "relay message: decoded mismatch"); - true - }}; - } - - match node { - Some(node) => match msg { - Message::Leave(leave_message) => { - encode_variant!(@relay Leave(LeaveMessage) <- (leave_message, node)) - } - Message::Join(join_message) => { - encode_variant!(@relay Join(JoinMessage) <- (join_message, node)) - } - Message::PushPull(push_pull_message) => { - encode_variant!(@relay PushPull(PushPullMessage) <- (push_pull_message, node)) - } - Message::UserEvent(user_event_message) => { - encode_variant!(@relay UserEvent(UserEventMessage) <- (user_event_message, node)) - } - Message::Query(query_message) => { - encode_variant!(@relay Query(QueryMessage) <- (query_message, node)) - } - Message::QueryResponse(query_response_message) => { - encode_variant!(@relay 
QueryResponse(QueryResponseMessage) <- (query_response_message, node)) - } - Message::ConflictResponse(conflict_response_message) => { - encode_variant!(@relay ConflictResponse(ConflictResponseMessage) <- (conflict_response_message, node)) - } - #[cfg(feature = "encryption")] - Message::KeyRequest(key_request_message) => { - encode_variant!(@relay KeyRequest(KeyRequestMessage) <- (key_request_message, node)) - } - #[cfg(feature = "encryption")] - Message::KeyResponse(key_response_message) => { - encode_variant!(@relay KeyResponse(KeyResponseMessage) <- (key_response_message, node)) - } - }, - None => match msg { - Message::Leave(msg) => encode_variant!( Leave(LeaveMessage) <- msg), - Message::Join(join_message) => encode_variant!( Join(JoinMessage) <- join_message), - Message::PushPull(push_pull_message) => { - encode_variant!( PushPull(PushPullMessage) <- push_pull_message) - } - Message::UserEvent(user_event_message) => { - encode_variant!( UserEvent(UserEventMessage) <- user_event_message) - } - Message::Query(query_message) => { - encode_variant!( Query(QueryMessage) <- query_message) - } - Message::QueryResponse(query_response_message) => { - encode_variant!( QueryResponse(QueryResponseMessage) <- query_response_message) - } - Message::ConflictResponse(conflict_response_message) => { - encode_variant!( ConflictResponse(ConflictResponseMessage) <- conflict_response_message) - } - #[cfg(feature = "encryption")] - Message::KeyRequest(key_request_message) => { - encode_variant!( KeyRequest(KeyRequestMessage) <- key_request_message) - } - #[cfg(feature = "encryption")] - Message::KeyResponse(key_response_message) => { - encode_variant!( KeyResponse(KeyResponseMessage) <- key_response_message) - } - }, - } -} - macro_rules! encodable_round_trip { (@message $(<$a:ty, $b:ty>),+$(,)?) 
=> { $( From e6483f0671ea8ec5f1f48e9b6efcb25e42d44307 Mon Sep 17 00:00:00 2001 From: al8n Date: Tue, 4 Mar 2025 22:51:08 +0800 Subject: [PATCH 34/39] cleanup --- CHANGELOG.md | 2 +- README.md | 2 +- examples/toyconsul/README.md | 2 +- serf-core/README.md | 1 + serf-core/src/lib.rs | 2 +- serf/README.md | 1 + 6 files changed, 6 insertions(+), 4 deletions(-) create mode 120000 serf-core/README.md create mode 120000 serf/README.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bc891f..4b20ca9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ ### Example -- Add [`toyconsul`](./examples/toyconsul/) Example +- Add [toyconsul](https://github.com/al8n/serf/tree/main/examples/), a toy eventually consistent distributed registry. ### Breakage diff --git a/README.md b/README.md index 66f2e39..2d24ab2 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,7 @@ serf is WASM/WASI friendly, all crates can be compiled to `wasm-wasi` and `wasm- ## Examples -See [examples/toyconsul](./examples/toy-consul/). +See [examples/toyconsul](https://github.com/al8n/serf/tree/main/examples/toyconsul). ## Protocol diff --git a/examples/toyconsul/README.md b/examples/toyconsul/README.md index ce0c4b9..baf260e 100644 --- a/examples/toyconsul/README.md +++ b/examples/toyconsul/README.md @@ -1,6 +1,6 @@ # ToyConsul -A toy eventually consensus distributed registry. +A toy eventually consistent distributed registry. 
## Installation diff --git a/serf-core/README.md b/serf-core/README.md new file mode 120000 index 0000000..32d46ee --- /dev/null +++ b/serf-core/README.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file diff --git a/serf-core/src/lib.rs b/serf-core/src/lib.rs index 092d41c..7a51d92 100644 --- a/serf-core/src/lib.rs +++ b/serf-core/src/lib.rs @@ -1,4 +1,4 @@ -#![doc = include_str!("../../README.md")] +#![doc = include_str!("../README.md")] #![doc(html_logo_url = "https://raw.githubusercontent.com/al8n/serf/main/art/logo_72x72.png")] #![forbid(unsafe_code)] // #![deny(warnings, missing_docs)] diff --git a/serf/README.md b/serf/README.md new file mode 120000 index 0000000..32d46ee --- /dev/null +++ b/serf/README.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file From 52e4faec9c139e055d1fca849089dd0cb49681dc Mon Sep 17 00:00:00 2001 From: al8n Date: Tue, 4 Mar 2025 22:51:11 +0800 Subject: [PATCH 35/39] Update lib.rs --- serf/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/serf/src/lib.rs b/serf/src/lib.rs index 68d37bb..b292918 100644 --- a/serf/src/lib.rs +++ b/serf/src/lib.rs @@ -1,4 +1,4 @@ -#![doc = include_str!("../../README.md")] +#![doc = include_str!("../README.md")] #![doc(html_logo_url = "https://raw.githubusercontent.com/al8n/serf/main/art/logo_72x72.png")] #![forbid(unsafe_code)] #![deny(warnings, missing_docs)] From 1be51560ca93724c814823d09db03af2b40cb69a Mon Sep 17 00:00:00 2001 From: al8n Date: Tue, 4 Mar 2025 23:10:58 +0800 Subject: [PATCH 36/39] Update Cargo.toml --- fuzz/Cargo.toml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 29c50df..45f4a37 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "serf-fuzz" +name = "serf-types-fuzz" version = "0.0.0" publish = false edition = "2024" @@ -10,9 +10,7 @@ cargo-fuzz = true [dependencies] libfuzzer-sys = "0.4" -[dependencies.serf-core] -path = "../serf-core" 
-features = ["arbitrary", "encryption"] +serf-core = { workspace = true, features = ["arbitrary", "default", "encryption", "test"] } [[bin]] name = "messages" From 7ceccc08ee3cc2614f03af68daad7ea9d1c77c91 Mon Sep 17 00:00:00 2001 From: al8n Date: Tue, 4 Mar 2025 23:12:49 +0800 Subject: [PATCH 37/39] Update .codecov.yml --- .codecov.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.codecov.yml b/.codecov.yml index 4aff77d..cda5191 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -2,9 +2,11 @@ codecov: require_ci_to_pass: false ignore: - - core/src/serf/base/tests - - core/src/serf/base/tests.rs + - serf-core/src/serf/base/tests + - serf-core/src/serf/base/tests.rs - serf/test + - fuzz/ + - examples/ coverage: status: From 57a92ed9a74ec24cf29855c3d7d59249c8913e97 Mon Sep 17 00:00:00 2001 From: al8n Date: Tue, 4 Mar 2025 23:29:45 +0800 Subject: [PATCH 38/39] Fix fuzzy --- .github/workflows/fuzz.yml | 2 +- serf-core/src/types/fuzzy.rs | 34 ++++++++++++++++------------------ 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/.github/workflows/fuzz.yml b/.github/workflows/fuzz.yml index 71f5175..62f4ff6 100644 --- a/.github/workflows/fuzz.yml +++ b/.github/workflows/fuzz.yml @@ -38,4 +38,4 @@ jobs: - name: Run fuzzing run: | cargo fuzz build - cargo fuzz run messages -- -max_total_time=300 + cargo fuzz run messages -- -max_len=4096 -max_total_time=300 diff --git a/serf-core/src/types/fuzzy.rs b/serf-core/src/types/fuzzy.rs index 5c69cc4..d1399b1 100644 --- a/serf-core/src/types/fuzzy.rs +++ b/serf-core/src/types/fuzzy.rs @@ -40,24 +40,22 @@ const _: () = { A: Arbitrary<'a>, { fn arbitrary(g: &mut Unstructured<'a>) -> arbitrary::Result { - loop { - let variant = MessageType::arbitrary(g)?; - - return Ok(match variant { - MessageType::ConflictResponse => Message::ConflictResponse(Arbitrary::arbitrary(g)?), - MessageType::Join => Message::Join(Arbitrary::arbitrary(g)?), - MessageType::Leave => Message::Leave(Arbitrary::arbitrary(g)?), 
- MessageType::PushPull => Message::PushPull(Arbitrary::arbitrary(g)?), - MessageType::Query => Message::Query(Arbitrary::arbitrary(g)?), - MessageType::QueryResponse => Message::QueryResponse(Arbitrary::arbitrary(g)?), - MessageType::UserEvent => Message::UserEvent(Arbitrary::arbitrary(g)?), - #[cfg(feature = "encryption")] - MessageType::KeyRequest => Message::KeyRequest(Arbitrary::arbitrary(g)?), - #[cfg(feature = "encryption")] - MessageType::KeyResponse => Message::KeyResponse(Arbitrary::arbitrary(g)?), - _ => continue, - }); - } + let variant = MessageType::arbitrary(g)?; + + Ok(match variant { + MessageType::ConflictResponse => Message::ConflictResponse(Arbitrary::arbitrary(g)?), + MessageType::Join => Message::Join(Arbitrary::arbitrary(g)?), + MessageType::Leave => Message::Leave(Arbitrary::arbitrary(g)?), + MessageType::PushPull => Message::PushPull(Arbitrary::arbitrary(g)?), + MessageType::Query => Message::Query(Arbitrary::arbitrary(g)?), + MessageType::QueryResponse => Message::QueryResponse(Arbitrary::arbitrary(g)?), + MessageType::UserEvent => Message::UserEvent(Arbitrary::arbitrary(g)?), + #[cfg(feature = "encryption")] + MessageType::KeyRequest => Message::KeyRequest(Arbitrary::arbitrary(g)?), + #[cfg(feature = "encryption")] + MessageType::KeyResponse => Message::KeyResponse(Arbitrary::arbitrary(g)?), + _ => Message::Query(QueryMessage::arbitrary(g)?), + }) } } }; From 85d15fe94a003cc34f39340147684c2cbe23f070 Mon Sep 17 00:00:00 2001 From: al8n Date: Tue, 4 Mar 2025 23:34:15 +0800 Subject: [PATCH 39/39] Add more unit tests --- serf-core/src/coalesce/member.rs | 561 +++++++++++++++---------------- serf-core/src/coalesce/user.rs | 273 ++++++++------- 2 files changed, 413 insertions(+), 421 deletions(-) diff --git a/serf-core/src/coalesce/member.rs b/serf-core/src/coalesce/member.rs index 2a29c4d..9ddbc00 100644 --- a/serf-core/src/coalesce/member.rs +++ b/serf-core/src/coalesce/member.rs @@ -112,286 +112,281 @@ where } } -// #[cfg(all(test, feature = 
"test"))] -// #[allow(clippy::collapsible_match)] -// mod tests { -// use std::{net::SocketAddr, time::Duration}; - -// use futures::FutureExt; -// use memberlist_core::{ -// agnostic_lite::{RuntimeLite, tokio::TokioRuntime}, -// transport::resolver::socket_addr::SocketAddrResolver, -// }; -// use crate::types::{MemberStatus, UserEventMessage}; -// use smol_str::SmolStr; - -// use crate::{ -// DefaultDelegate, -// coalesce::coalesced_event, -// event::{CrateEventType, MemberEvent}, -// }; - -// use super::*; - -// type Transport = UnimplementedTransport< -// SmolStr, -// SocketAddrResolver, -// -// TokioRuntime, -// >; - -// type Delegate = DefaultDelegate; - -// #[tokio::test] -// async fn test_member_event_coealesce_basic() { -// let (tx, rx) = async_channel::unbounded(); -// let (_shutdown_tx, shutdown_rx) = async_channel::bounded(1); -// let coalescer = MemberEventCoalescer::::new(); - -// let in_ = coalesced_event( -// tx, -// shutdown_rx, -// Duration::from_millis(20), -// Duration::from_millis(20), -// coalescer, -// ); - -// let send = vec![ -// MemberEvent { -// ty: MemberEventType::Join, -// members: TinyVec::from(Member::new( -// Node::new("foo".into(), "127.0.0.1:8080".parse().unwrap()), -// Default::default(), -// MemberStatus::None, -// )) -// .into(), -// }, -// MemberEvent { -// ty: MemberEventType::Leave, -// members: TinyVec::from(Member::new( -// Node::new("foo".into(), "127.0.0.1:8080".parse().unwrap()), -// Default::default(), -// MemberStatus::None, -// )) -// .into(), -// }, -// MemberEvent { -// ty: MemberEventType::Leave, -// members: TinyVec::from(Member::new( -// Node::new("bar".into(), "127.0.0.1:8080".parse().unwrap()), -// Default::default(), -// MemberStatus::None, -// )) -// .into(), -// }, -// MemberEvent { -// ty: MemberEventType::Update, -// members: TinyVec::from(Member::new( -// Node::new("zip".into(), "127.0.0.1:8080".parse().unwrap()), -// [("role", "foo")].into_iter().collect(), -// MemberStatus::None, -// )) -// .into(), -// 
}, -// MemberEvent { -// ty: MemberEventType::Update, -// members: TinyVec::from(Member::new( -// Node::new("zip".into(), "127.0.0.1:8080".parse().unwrap()), -// [("role", "bar")].into_iter().collect(), -// MemberStatus::None, -// )) -// .into(), -// }, -// MemberEvent { -// ty: MemberEventType::Reap, -// members: TinyVec::from(Member::new( -// Node::new("dead".into(), "127.0.0.1:8080".parse().unwrap()), -// Default::default(), -// MemberStatus::None, -// )) -// .into(), -// }, -// ]; - -// for event in send { -// in_.send(CrateEvent::from(event)).await.unwrap(); -// } - -// let mut events = HashMap::new(); -// let timeout = TokioRuntime::sleep(Duration::from_millis(40)); -// futures::pin_mut!(timeout); -// loop { -// futures::select! { -// e = rx.recv().fuse() => { -// let e = e.unwrap(); -// events.insert(e.ty(), e.clone()); -// } -// _ = (&mut timeout).fuse() => { -// break; -// }, -// } -// } - -// assert_eq!(events.len(), 3); - -// match events.get(&CrateEventType::Member(MemberEventType::Leave)) { -// None => panic!(""), -// Some(e) => match e { -// CrateEvent::Member(MemberEvent { members, .. }) => { -// assert_eq!(members.len(), 2); - -// let expected = ["bar", "foo"]; -// let mut names = [members[0].node.id().clone(), members[1].node.id().clone()]; -// names.sort(); - -// assert_eq!(names, expected); -// } -// _ => panic!(""), -// }, -// } - -// match events.get(&CrateEventType::Member(MemberEventType::Update)) { -// None => panic!(""), -// Some(e) => match e { -// CrateEvent::Member(MemberEvent { members, .. }) => { -// assert_eq!(members.len(), 1); -// assert_eq!(members[0].node.id(), "zip"); -// assert_eq!(members[0].tags().get("role").unwrap(), "bar"); -// } -// _ => panic!(""), -// }, -// } - -// match events.get(&CrateEventType::Member(MemberEventType::Reap)) { -// None => panic!(""), -// Some(e) => match e { -// CrateEvent::Member(MemberEvent { members, .. 
}) => { -// assert_eq!(members.len(), 1); -// assert_eq!(members[0].node.id(), "dead"); -// } -// _ => panic!(""), -// }, -// } -// } - -// #[tokio::test] -// async fn test_member_event_coalesce_tag_update() { -// let (tx, rx) = async_channel::unbounded(); -// let (_shutdown_tx, shutdown_rx) = async_channel::bounded(1); -// let coalescer = MemberEventCoalescer::::new(); - -// let in_ = coalesced_event( -// tx, -// shutdown_rx, -// Duration::from_millis(5), -// Duration::from_millis(5), -// coalescer, -// ); - -// in_ -// .send(CrateEvent::from(MemberEvent { -// ty: MemberEventType::Update, -// members: TinyVec::from(Member::new( -// Node::new("foo".into(), "127.0.0.1:8080".parse().unwrap()), -// [("role", "foo")].into_iter().collect(), -// MemberStatus::None, -// )) -// .into(), -// })) -// .await -// .unwrap(); - -// TokioRuntime::sleep(Duration::from_millis(30)).await; - -// futures::select! { -// e = rx.recv().fuse() => { -// let e = e.unwrap(); - -// match e { -// CrateEvent::Member(MemberEvent { ty, .. }) => { -// assert!(matches!(ty, MemberEventType::Update)); -// } -// _ => panic!("expected update"), -// } -// } -// default => panic!("expected update"), -// } - -// // Second update should not be suppressed even though -// // last event was an update -// in_ -// .send(CrateEvent::from(MemberEvent { -// ty: MemberEventType::Update, -// members: TinyVec::from(Member::new( -// Node::new("foo".into(), "127.0.0.1:8080".parse().unwrap()), -// [("role", "bar")].into_iter().collect(), -// MemberStatus::None, -// )) -// .into(), -// })) -// .await -// .unwrap(); -// TokioRuntime::sleep(Duration::from_millis(10)).await; - -// futures::select! { -// e = rx.recv().fuse() => { -// let e = e.unwrap(); - -// match e { -// CrateEvent::Member(MemberEvent { ty, .. 
}) => { -// assert!(matches!(ty, MemberEventType::Update)); -// } -// _ => panic!("expected update"), -// } -// } -// default => panic!("expected update"), -// } -// } - -// #[test] -// fn test_member_event_coalesce_pass_through() { -// let cases = [ -// (CrateEvent::from(UserEventMessage::default()), false), -// ( -// CrateEvent::from(MemberEvent { -// ty: MemberEventType::Join, -// members: TinyVec::new().into(), -// }), -// true, -// ), -// ( -// CrateEvent::from(MemberEvent { -// ty: MemberEventType::Leave, -// members: TinyVec::new().into(), -// }), -// true, -// ), -// ( -// CrateEvent::from(MemberEvent { -// ty: MemberEventType::Failed, -// members: TinyVec::new().into(), -// }), -// true, -// ), -// ( -// CrateEvent::from(MemberEvent { -// ty: MemberEventType::Update, -// members: TinyVec::new().into(), -// }), -// true, -// ), -// ( -// CrateEvent::from(MemberEvent { -// ty: MemberEventType::Reap, -// members: TinyVec::new().into(), -// }), -// true, -// ), -// ]; - -// for (event, handle) in cases.iter() { -// let coalescer = MemberEventCoalescer::::new(); -// assert_eq!(coalescer.handle(event), *handle); -// } -// } -// } +#[cfg(all(test, feature = "test"))] +#[allow(clippy::collapsible_match)] +mod tests { + use std::{net::SocketAddr, time::Duration}; + + use crate::types::{MemberStatus, UserEventMessage}; + use futures::FutureExt; + use memberlist_core::{ + agnostic_lite::{RuntimeLite, tokio::TokioRuntime}, + transport::{resolver::socket_addr::SocketAddrResolver, unimplemented::UnimplementedTransport}, + }; + use smol_str::SmolStr; + + use crate::{ + DefaultDelegate, + coalesce::coalesced_event, + event::{CrateEventType, MemberEvent}, + }; + + use super::*; + + type Transport = UnimplementedTransport, TokioRuntime>; + + type Delegate = DefaultDelegate; + + #[tokio::test] + async fn test_member_event_coealesce_basic() { + let (tx, rx) = async_channel::unbounded(); + let (_shutdown_tx, shutdown_rx) = async_channel::bounded(1); + let coalescer = 
MemberEventCoalescer::::new(); + + let in_ = coalesced_event( + tx, + shutdown_rx, + Duration::from_millis(20), + Duration::from_millis(20), + coalescer, + ); + + let send = vec![ + MemberEvent { + ty: MemberEventType::Join, + members: TinyVec::from(Member::new( + Node::new("foo".into(), "127.0.0.1:8080".parse().unwrap()), + Default::default(), + MemberStatus::None, + )) + .into(), + }, + MemberEvent { + ty: MemberEventType::Leave, + members: TinyVec::from(Member::new( + Node::new("foo".into(), "127.0.0.1:8080".parse().unwrap()), + Default::default(), + MemberStatus::None, + )) + .into(), + }, + MemberEvent { + ty: MemberEventType::Leave, + members: TinyVec::from(Member::new( + Node::new("bar".into(), "127.0.0.1:8080".parse().unwrap()), + Default::default(), + MemberStatus::None, + )) + .into(), + }, + MemberEvent { + ty: MemberEventType::Update, + members: TinyVec::from(Member::new( + Node::new("zip".into(), "127.0.0.1:8080".parse().unwrap()), + [("role", "foo")].into_iter().collect(), + MemberStatus::None, + )) + .into(), + }, + MemberEvent { + ty: MemberEventType::Update, + members: TinyVec::from(Member::new( + Node::new("zip".into(), "127.0.0.1:8080".parse().unwrap()), + [("role", "bar")].into_iter().collect(), + MemberStatus::None, + )) + .into(), + }, + MemberEvent { + ty: MemberEventType::Reap, + members: TinyVec::from(Member::new( + Node::new("dead".into(), "127.0.0.1:8080".parse().unwrap()), + Default::default(), + MemberStatus::None, + )) + .into(), + }, + ]; + + for event in send { + in_.send(CrateEvent::from(event)).await.unwrap(); + } + + let mut events = HashMap::new(); + let timeout = TokioRuntime::sleep(Duration::from_millis(40)); + futures::pin_mut!(timeout); + loop { + futures::select! 
{ + e = rx.recv().fuse() => { + let e = e.unwrap(); + events.insert(e.ty(), e.clone()); + } + _ = (&mut timeout).fuse() => { + break; + }, + } + } + + assert_eq!(events.len(), 3); + + match events.get(&CrateEventType::Member(MemberEventType::Leave)) { + None => panic!(""), + Some(e) => match e { + CrateEvent::Member(MemberEvent { members, .. }) => { + assert_eq!(members.len(), 2); + + let expected = ["bar", "foo"]; + let mut names = [members[0].node.id().clone(), members[1].node.id().clone()]; + names.sort(); + + assert_eq!(names, expected); + } + _ => panic!(""), + }, + } + + match events.get(&CrateEventType::Member(MemberEventType::Update)) { + None => panic!(""), + Some(e) => match e { + CrateEvent::Member(MemberEvent { members, .. }) => { + assert_eq!(members.len(), 1); + assert_eq!(members[0].node.id(), "zip"); + assert_eq!(members[0].tags().get("role").unwrap(), "bar"); + } + _ => panic!(""), + }, + } + + match events.get(&CrateEventType::Member(MemberEventType::Reap)) { + None => panic!(""), + Some(e) => match e { + CrateEvent::Member(MemberEvent { members, .. }) => { + assert_eq!(members.len(), 1); + assert_eq!(members[0].node.id(), "dead"); + } + _ => panic!(""), + }, + } + } + + #[tokio::test] + async fn test_member_event_coalesce_tag_update() { + let (tx, rx) = async_channel::unbounded(); + let (_shutdown_tx, shutdown_rx) = async_channel::bounded(1); + let coalescer = MemberEventCoalescer::::new(); + + let in_ = coalesced_event( + tx, + shutdown_rx, + Duration::from_millis(5), + Duration::from_millis(5), + coalescer, + ); + + in_ + .send(CrateEvent::from(MemberEvent { + ty: MemberEventType::Update, + members: TinyVec::from(Member::new( + Node::new("foo".into(), "127.0.0.1:8080".parse().unwrap()), + [("role", "foo")].into_iter().collect(), + MemberStatus::None, + )) + .into(), + })) + .await + .unwrap(); + + TokioRuntime::sleep(Duration::from_millis(30)).await; + + futures::select! 
{ + e = rx.recv().fuse() => { + let e = e.unwrap(); + + match e { + CrateEvent::Member(MemberEvent { ty, .. }) => { + assert!(matches!(ty, MemberEventType::Update)); + } + _ => panic!("expected update"), + } + } + default => panic!("expected update"), + } + + // Second update should not be suppressed even though + // last event was an update + in_ + .send(CrateEvent::from(MemberEvent { + ty: MemberEventType::Update, + members: TinyVec::from(Member::new( + Node::new("foo".into(), "127.0.0.1:8080".parse().unwrap()), + [("role", "bar")].into_iter().collect(), + MemberStatus::None, + )) + .into(), + })) + .await + .unwrap(); + TokioRuntime::sleep(Duration::from_millis(10)).await; + + futures::select! { + e = rx.recv().fuse() => { + let e = e.unwrap(); + + match e { + CrateEvent::Member(MemberEvent { ty, .. }) => { + assert!(matches!(ty, MemberEventType::Update)); + } + _ => panic!("expected update"), + } + } + default => panic!("expected update"), + } + } + + #[test] + fn test_member_event_coalesce_pass_through() { + let cases = [ + (CrateEvent::from(UserEventMessage::default()), false), + ( + CrateEvent::from(MemberEvent { + ty: MemberEventType::Join, + members: TinyVec::new().into(), + }), + true, + ), + ( + CrateEvent::from(MemberEvent { + ty: MemberEventType::Leave, + members: TinyVec::new().into(), + }), + true, + ), + ( + CrateEvent::from(MemberEvent { + ty: MemberEventType::Failed, + members: TinyVec::new().into(), + }), + true, + ), + ( + CrateEvent::from(MemberEvent { + ty: MemberEventType::Update, + members: TinyVec::new().into(), + }), + true, + ), + ( + CrateEvent::from(MemberEvent { + ty: MemberEventType::Reap, + members: TinyVec::new().into(), + }), + true, + ), + ]; + + for (event, handle) in cases.iter() { + let coalescer = MemberEventCoalescer::::new(); + assert_eq!(coalescer.handle(event), *handle); + } + } +} diff --git a/serf-core/src/coalesce/user.rs b/serf-core/src/coalesce/user.rs index e5f868a..152d66e 100644 --- a/serf-core/src/coalesce/user.rs 
+++ b/serf-core/src/coalesce/user.rs @@ -97,141 +97,138 @@ where } } -// #[cfg(all(test, feature = "test"))] -// mod tests { -// use std::net::SocketAddr; - -// use agnostic_lite::tokio::TokioRuntime; -// use memberlist_core::transport::resolver::socket_addr::SocketAddrResolver; - -// use crate::{ -// DefaultDelegate, -// event::{MemberEvent, MemberEventType}, -// }; - -// use super::*; - -// type Transport = UnimplementedTransport< -// SmolStr, -// SocketAddrResolver, -// -// TokioRuntime, -// >; - -// type Delegate = DefaultDelegate; - -// #[tokio::test] -// async fn test_user_event_coalesce_basic() { -// let (tx, rx) = async_channel::unbounded(); -// let (_shutdown_tx, shutdown_rx) = async_channel::bounded(1); -// let coalescer = UserEventCoalescer::::new(); - -// let in_ = coalesced_event( -// tx, -// shutdown_rx, -// Duration::from_millis(20), -// Duration::from_millis(20), -// coalescer, -// ); - -// let send = vec![ -// UserEventMessage::default() -// .with_name("foo".into()) -// .with_cc(true) -// .with_ltime(1.into()), -// UserEventMessage::default() -// .with_name("foo".into()) -// .with_cc(true) -// .with_ltime(2.into()), -// UserEventMessage::default() -// .with_name("bar".into()) -// .with_cc(true) -// .with_ltime(2.into()) -// .with_payload("test1".into()), -// UserEventMessage::default() -// .with_name("bar".into()) -// .with_cc(true) -// .with_ltime(2.into()) -// .with_payload("test2".into()), -// ]; - -// for event in send { -// in_.send(CrateEvent::from(event)).await.unwrap(); -// } - -// let mut got_foo = false; -// let mut got_bar1 = false; -// let mut got_bar2 = false; - -// loop { -// futures::select! 
{ -// _ = TokioRuntime::sleep(Duration::from_millis(40)).fuse() => break, -// event = rx.recv().fuse() => { -// let event = event.unwrap(); -// match event { -// CrateEvent::User(e) => { -// match e.name().as_str() { -// "foo" => { -// assert_eq!(e.ltime(), 2.into(), "bad ltime for foo"); -// got_foo = true; -// } -// "bar" => { -// assert_eq!(e.ltime(), 2.into(), "bad ltime for bar"); -// if e.payload().eq("test1".as_bytes()) { -// got_bar1 = true; -// } - -// if e.payload().eq("test2".as_bytes()) { -// got_bar2 = true; -// } -// } -// _ => unreachable!(), -// } -// } -// _ => unreachable!(), -// } -// } -// } -// } - -// assert!(got_foo && got_bar1 && got_bar2, "missing events"); -// } - -// #[test] -// fn test_user_event_coalesce_pass_through() { -// let cases = [ -// (CrateEvent::from(UserEventMessage::default()), false), -// ( -// CrateEvent::from(UserEventMessage::default().with_cc(true)), -// true, -// ), -// ( -// CrateEvent::from(MemberEvent { -// ty: MemberEventType::Join, -// members: TinyVec::new().into(), -// }), -// false, -// ), -// ( -// CrateEvent::from(MemberEvent { -// ty: MemberEventType::Leave, -// members: TinyVec::new().into(), -// }), -// false, -// ), -// ( -// CrateEvent::from(MemberEvent { -// ty: MemberEventType::Failed, -// members: TinyVec::new().into(), -// }), -// false, -// ), -// ]; - -// let coalescer = UserEventCoalescer::::new(); - -// for (idx, (event, should_coalesce)) in cases.iter().enumerate() { -// assert_eq!(coalescer.handle(event), *should_coalesce, "bad: {idx}"); -// } -// } -// } +#[cfg(all(test, feature = "test"))] +mod tests { + use std::net::SocketAddr; + + use agnostic_lite::tokio::TokioRuntime; + use memberlist_core::transport::{ + resolver::socket_addr::SocketAddrResolver, unimplemented::UnimplementedTransport, + }; + + use crate::{ + DefaultDelegate, + event::{MemberEvent, MemberEventType}, + }; + + use super::*; + + type Transport = UnimplementedTransport, TokioRuntime>; + + type Delegate = DefaultDelegate; + + 
#[tokio::test] + async fn test_user_event_coalesce_basic() { + let (tx, rx) = async_channel::unbounded(); + let (_shutdown_tx, shutdown_rx) = async_channel::bounded(1); + let coalescer = UserEventCoalescer::::new(); + + let in_ = coalesced_event( + tx, + shutdown_rx, + Duration::from_millis(20), + Duration::from_millis(20), + coalescer, + ); + + let send = vec![ + UserEventMessage::default() + .with_name("foo".into()) + .with_cc(true) + .with_ltime(1.into()), + UserEventMessage::default() + .with_name("foo".into()) + .with_cc(true) + .with_ltime(2.into()), + UserEventMessage::default() + .with_name("bar".into()) + .with_cc(true) + .with_ltime(2.into()) + .with_payload("test1".into()), + UserEventMessage::default() + .with_name("bar".into()) + .with_cc(true) + .with_ltime(2.into()) + .with_payload("test2".into()), + ]; + + for event in send { + in_.send(CrateEvent::from(event)).await.unwrap(); + } + + let mut got_foo = false; + let mut got_bar1 = false; + let mut got_bar2 = false; + + loop { + futures::select! 
{ + _ = TokioRuntime::sleep(Duration::from_millis(40)).fuse() => break, + event = rx.recv().fuse() => { + let event = event.unwrap(); + match event { + CrateEvent::User(e) => { + match e.name().as_str() { + "foo" => { + assert_eq!(e.ltime(), 2.into(), "bad ltime for foo"); + got_foo = true; + } + "bar" => { + assert_eq!(e.ltime(), 2.into(), "bad ltime for bar"); + if e.payload().eq("test1".as_bytes()) { + got_bar1 = true; + } + + if e.payload().eq("test2".as_bytes()) { + got_bar2 = true; + } + } + _ => unreachable!(), + } + } + _ => unreachable!(), + } + } + } + } + + assert!(got_foo && got_bar1 && got_bar2, "missing events"); + } + + #[test] + fn test_user_event_coalesce_pass_through() { + let cases = [ + (CrateEvent::from(UserEventMessage::default()), false), + ( + CrateEvent::from(UserEventMessage::default().with_cc(true)), + true, + ), + ( + CrateEvent::from(MemberEvent { + ty: MemberEventType::Join, + members: TinyVec::new().into(), + }), + false, + ), + ( + CrateEvent::from(MemberEvent { + ty: MemberEventType::Leave, + members: TinyVec::new().into(), + }), + false, + ), + ( + CrateEvent::from(MemberEvent { + ty: MemberEventType::Failed, + members: TinyVec::new().into(), + }), + false, + ), + ]; + + let coalescer = UserEventCoalescer::::new(); + + for (idx, (event, should_coalesce)) in cases.iter().enumerate() { + assert_eq!(coalescer.handle(event), *should_coalesce, "bad: {idx}"); + } + } +}