1#![warn(missing_debug_implementations, missing_docs, unreachable_pub)]
5
6use crate::filter::AsyncFilter;
7use futures_util::future;
8use pin_project_lite::pin_project;
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11use std::{
12 pin::Pin,
13 task::{Context, Poll},
14};
15use tracing::error;
16
17mod delay;
18mod latency;
19mod rotating_histogram;
20mod select;
21
22use delay::Delay;
23use latency::Latency;
24use rotating_histogram::RotatingHistogram;
25use select::Select;
26
27type Histo = Arc<Mutex<RotatingHistogram>>;
28type Service<S, P> = select::Select<
29 SelectPolicy<P>,
30 Latency<Histo, S>,
31 Delay<DelayPolicy, AsyncFilter<Latency<Histo, S>, PolicyPredicate<P>>>,
32>;
33
/// Hedging middleware around a service `S`, controlled by a policy `P`.
///
/// This is a newtype over the internal `Service<S, P>` stack: two
/// latency-recording copies of `S`, one called directly and one placed behind
/// a policy filter and a histogram-derived delay, combined through `Select`.
/// Use [`Hedge::new`] to build one.
#[derive(Debug)]
pub struct Hedge<S, P>(Service<S, P>);
39
pin_project! {
    /// The future returned by the [`Hedge`] service's `call`.
    ///
    /// It wraps the future produced by the inner service `S` and, when
    /// polled, converts that future's error into `crate::BoxError`.
    #[derive(Debug)]
    pub struct Future<S, Request>
    where
        S: tower_service::Service<Request>,
    {
        #[pin]
        inner: S::Future,
    }
}
53
/// Policy deciding which requests may be duplicated and whether the duplicate
/// is actually passed on.
pub trait Policy<Request> {
    /// Produces a copy of `req`, or `None` if no copy should be made.
    ///
    /// `SelectPolicy` calls this first and only hands the copy onwards once
    /// the latency histogram holds at least `min_data_points` entries.
    fn clone_request(&self, req: &Request) -> Option<Request>;

    /// Reports whether the copied request should be let through to the
    /// delayed service.
    ///
    /// When this returns `false`, `PolicyPredicate` yields a future that never
    /// resolves instead of an error.
    fn can_retry(&self, req: &Request) -> bool;
}
63
// Adapter that lets a `Policy` act as the `AsyncPredicate` of the filter
// layer (see the `AsyncPredicate` impl below). It is `pub` but hidden from the
// docs; it is part of the `Service<S, P>` type that appears in `Hedge`'s
// associated `Future` type.
#[doc(hidden)]
#[derive(Clone, Debug)]
pub struct PolicyPredicate<P>(P);
69
// Delay policy for the second branch: the delay is read from the shared
// histogram at `latency_percentile` (see the `delay::Policy` impl below).
#[doc(hidden)]
#[derive(Debug)]
pub struct DelayPolicy {
    // Shared latency histogram; values are recorded in milliseconds.
    histo: Histo,
    // Passed to `value_at_quantile` when computing the delay.
    // NOTE(review): presumably a fraction in 0.0..=1.0 — confirm against the
    // histogram API.
    latency_percentile: f32,
}
76
// Select policy: decides whether a request copy is produced for the second
// branch (see the `select::Policy` impl below).
#[doc(hidden)]
#[derive(Debug)]
pub struct SelectPolicy<P> {
    // User policy asked to clone the request.
    policy: P,
    // Shared latency histogram whose sample count gates cloning.
    histo: Histo,
    // Minimum number of samples the histogram must report before a cloned
    // request is handed out.
    min_data_points: u64,
}
84
85impl<S, P> Hedge<S, P> {
86 pub fn new<Request>(
88 service: S,
89 policy: P,
90 min_data_points: u64,
91 latency_percentile: f32,
92 period: Duration,
93 ) -> Hedge<S, P>
94 where
95 S: tower_service::Service<Request> + Clone,
96 S::Error: Into<crate::BoxError>,
97 P: Policy<Request> + Clone,
98 {
99 let histo = Arc::new(Mutex::new(RotatingHistogram::new(period)));
100 Self::new_with_histo(service, policy, min_data_points, latency_percentile, histo)
101 }
102
103 pub fn new_with_mock_latencies<Request>(
106 service: S,
107 policy: P,
108 min_data_points: u64,
109 latency_percentile: f32,
110 period: Duration,
111 latencies_ms: &[u64],
112 ) -> Hedge<S, P>
113 where
114 S: tower_service::Service<Request> + Clone,
115 S::Error: Into<crate::BoxError>,
116 P: Policy<Request> + Clone,
117 {
118 let histo = Arc::new(Mutex::new(RotatingHistogram::new(period)));
119 {
120 let mut locked = histo.lock().unwrap();
121 for latency in latencies_ms.iter() {
122 locked.read().record(*latency).unwrap();
123 }
124 }
125 Self::new_with_histo(service, policy, min_data_points, latency_percentile, histo)
126 }
127
128 fn new_with_histo<Request>(
129 service: S,
130 policy: P,
131 min_data_points: u64,
132 latency_percentile: f32,
133 histo: Histo,
134 ) -> Hedge<S, P>
135 where
136 S: tower_service::Service<Request> + Clone,
137 S::Error: Into<crate::BoxError>,
138 P: Policy<Request> + Clone,
139 {
140 let recorded_a = Latency::new(histo.clone(), service.clone());
143 let recorded_b = Latency::new(histo.clone(), service);
144
145 let filtered = AsyncFilter::new(recorded_b, PolicyPredicate(policy.clone()));
147
148 let delay_policy = DelayPolicy {
151 histo: histo.clone(),
152 latency_percentile,
153 };
154 let delayed = Delay::new(delay_policy, filtered);
155
156 let select_policy = SelectPolicy {
159 policy,
160 histo,
161 min_data_points,
162 };
163 Hedge(Select::new(select_policy, recorded_a, delayed))
164 }
165}
166
167impl<S, P, Request> tower_service::Service<Request> for Hedge<S, P>
168where
169 S: tower_service::Service<Request> + Clone,
170 S::Error: Into<crate::BoxError>,
171 P: Policy<Request> + Clone,
172{
173 type Response = S::Response;
174 type Error = crate::BoxError;
175 type Future = Future<Service<S, P>, Request>;
176
177 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
178 self.0.poll_ready(cx)
179 }
180
181 fn call(&mut self, request: Request) -> Self::Future {
182 Future {
183 inner: self.0.call(request),
184 }
185 }
186}
187
188impl<S, Request> std::future::Future for Future<S, Request>
189where
190 S: tower_service::Service<Request>,
191 S::Error: Into<crate::BoxError>,
192{
193 type Output = Result<S::Response, crate::BoxError>;
194
195 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
196 self.project().inner.poll(cx).map_err(Into::into)
197 }
198}
199
/// Number of nanoseconds in one millisecond.
const NANOS_PER_MILLI: u32 = 1_000_000;
/// Number of milliseconds in one second.
const MILLIS_PER_SEC: u64 = 1_000;

/// Converts `duration` to whole milliseconds, rounding any fractional
/// millisecond **up** and saturating at `u64::MAX` instead of overflowing.
fn millis(duration: Duration) -> u64 {
    let nanos = duration.subsec_nanos();
    // Ceiling division expressed as "full milliseconds, plus one if anything
    // is left over".
    let full_millis = nanos / NANOS_PER_MILLI;
    let leftover = u32::from(nanos % NANOS_PER_MILLI != 0);

    let from_secs = duration.as_secs().saturating_mul(MILLIS_PER_SEC);
    from_secs.saturating_add(u64::from(full_millis + leftover))
}
211
212impl latency::Record for Histo {
213 fn record(&mut self, latency: Duration) {
214 let mut locked = self.lock().unwrap();
215 locked.write().record(millis(latency)).unwrap_or_else(|e| {
216 error!("Failed to write to hedge histogram: {:?}", e);
217 })
218 }
219}
220
221impl<P, Request> crate::filter::AsyncPredicate<Request> for PolicyPredicate<P>
222where
223 P: Policy<Request>,
224{
225 type Future = future::Either<
226 future::Ready<Result<Request, crate::BoxError>>,
227 future::Pending<Result<Request, crate::BoxError>>,
228 >;
229 type Request = Request;
230
231 fn check(&mut self, request: Request) -> Self::Future {
232 if self.0.can_retry(&request) {
233 future::Either::Left(future::ready(Ok(request)))
234 } else {
235 future::Either::Right(future::pending())
240 }
241 }
242}
243
244impl<Request> delay::Policy<Request> for DelayPolicy {
245 fn delay(&self, _req: &Request) -> Duration {
246 let mut locked = self.histo.lock().unwrap();
247 let millis = locked
248 .read()
249 .value_at_quantile(self.latency_percentile.into());
250 Duration::from_millis(millis)
251 }
252}
253
254impl<P, Request> select::Policy<Request> for SelectPolicy<P>
255where
256 P: Policy<Request>,
257{
258 fn clone_request(&self, req: &Request) -> Option<Request> {
259 self.policy.clone_request(req).filter(|_| {
260 let mut locked = self.histo.lock().unwrap();
261 locked.read().len() >= self.min_data_points
264 })
265 }
266}