1use crate::{endpoints::EndpointGroup, error::Trace, hash::thread_local_xxhash};
2use junction_api::backend::{Backend, LbPolicy, RequestHashPolicy, RequestHasher, RingHashParams};
3use smol_str::ToSmolStr;
4use std::{
5 net::SocketAddr,
6 sync::{
7 atomic::{AtomicUsize, Ordering},
8 RwLock,
9 },
10};
11
/// A backend's configuration paired with the load balancer state used for it.
#[derive(Debug)]
pub struct BackendLb {
    /// Backend configuration associated with this load balancer.
    pub config: Backend,
    /// Load balancing state used to choose an endpoint for each request.
    pub load_balancer: LoadBalancer,
}
19
/// Load balancing strategy state, one variant per supported policy.
///
/// Built from an [`LbPolicy`] by `LoadBalancer::from_config`.
#[derive(Debug)]
pub enum LoadBalancer {
    /// Cycles through the endpoints of a group in order.
    RoundRobin(RoundRobinLb),
    /// Hashes request attributes onto a ring of endpoint points.
    RingHash(RingHashLb),
}
25
26impl LoadBalancer {
27 pub(crate) fn load_balance<'e>(
28 &self,
29 trace: &mut Trace,
30 endpoints: &'e EndpointGroup,
31 url: &crate::Url,
32 headers: &http::HeaderMap,
33 previous_addrs: &[SocketAddr],
34 ) -> Option<&'e SocketAddr> {
35 match self {
36 LoadBalancer::RoundRobin(lb) => lb.pick_endpoint(trace, endpoints, previous_addrs),
38 LoadBalancer::RingHash(lb) => {
40 lb.pick_endpoint(trace, endpoints, url, headers, &lb.config.hash_params)
41 }
42 }
43 }
44}
45
46impl LoadBalancer {
47 pub(crate) fn from_config(config: &LbPolicy) -> Self {
48 match config {
49 LbPolicy::RoundRobin => LoadBalancer::RoundRobin(RoundRobinLb::default()),
50 LbPolicy::RingHash(x) => LoadBalancer::RingHash(RingHashLb::new(x)),
51 LbPolicy::Unspecified => LoadBalancer::RoundRobin(RoundRobinLb::default()),
52 }
53 }
54}
55
56#[derive(Debug, Default)]
62pub struct RoundRobinLb {
63 idx: AtomicUsize,
64}
65
66impl RoundRobinLb {
67 fn pick_endpoint<'e>(
69 &self,
70 trace: &mut Trace,
71 endpoint_group: &'e EndpointGroup,
72 previous_addrs: &[SocketAddr],
73 ) -> Option<&'e SocketAddr> {
74 let _ = previous_addrs;
81
82 let idx = self.idx.fetch_add(1, Ordering::SeqCst) % endpoint_group.len();
83 let addr = endpoint_group.nth(idx);
84
85 trace.load_balance("ROUND_ROBIN", addr, Vec::new());
86
87 addr
88 }
89}
90
/// Ring hash (consistent hashing) load balancer state.
///
/// The ring is built lazily and rebuilt whenever the hash of the endpoint
/// group being balanced over differs from the one it was last built for.
#[derive(Debug)]
pub struct RingHashLb {
    /// Ring parameters: the minimum ring size and the request hash policies.
    config: RingHashParams,
    /// The current ring, behind a lock so concurrent picks can share it.
    ring: RwLock<Ring>,
}
103
/// A single point on the hash ring.
#[derive(Debug, Clone, PartialEq, Eq)]
struct RingEntry {
    /// Position of this point on the ring.
    hash: u64,
    /// Index of the owning endpoint within the endpoint group the ring was built from.
    idx: usize,
}
109
110impl RingHashLb {
111 fn new(config: &RingHashParams) -> Self {
112 Self {
113 config: config.clone(),
114 ring: RwLock::new(Ring {
115 eg_hash: 0,
116 entries: Vec::with_capacity(config.min_ring_size as usize),
117 }),
118 }
119 }
120
121 fn pick_endpoint<'e>(
122 &self,
123 trace: &mut Trace,
124 endpoints: &'e EndpointGroup,
125 url: &crate::Url,
126 headers: &http::HeaderMap,
127 hash_params: &Vec<RequestHashPolicy>,
128 ) -> Option<&'e SocketAddr> {
129 let request_hash =
130 hash_request(hash_params, url, headers).unwrap_or_else(crate::rand::random);
131
132 let endpoint_idx = self.with_ring(endpoints, |r| r.pick(request_hash))?;
133 let addr = endpoints.nth(endpoint_idx);
134
135 trace.load_balance("RING_HASH", addr, vec![("hash", request_hash.to_smolstr())]);
136
137 addr
138 }
139
140 fn with_ring<F, T>(&self, endpoint_group: &EndpointGroup, mut cb: F) -> T
149 where
150 F: FnMut(&Ring) -> T,
151 {
152 let ring = self.ring.read().unwrap();
157 if ring.eg_hash == endpoint_group.hash {
158 return cb(&ring);
159 }
160 std::mem::drop(ring);
161
162 let mut ring = self.ring.write().unwrap();
164 ring.rebuild(self.config.min_ring_size as usize, endpoint_group);
165 cb(&ring)
166 }
167}
168
169#[derive(Debug)]
170struct Ring {
171 eg_hash: u64,
174 entries: Vec<RingEntry>,
175}
176
177impl Ring {
178 fn rebuild(&mut self, min_size: usize, endpoint_group: &EndpointGroup) {
179 let endpoint_count = endpoint_group.len();
184 let repeats = usize::max((min_size as f64 / endpoint_count as f64).ceil() as usize, 1);
185 let ring_size = repeats * endpoint_count;
186
187 self.entries.clear();
188 self.entries.reserve(ring_size);
189
190 for (idx, endpoint) in endpoint_group.iter().enumerate() {
191 for i in 0..repeats {
192 let hash = thread_local_xxhash::hash(&(endpoint, i));
201 self.entries.push(RingEntry { hash, idx });
202 }
203 }
204
205 self.eg_hash = endpoint_group.hash;
206 self.entries.sort_by_key(|e| e.hash);
207 }
208
209 fn pick(&self, endpoint_hash: u64) -> Option<usize> {
210 if self.entries.is_empty() {
211 return None;
212 }
213
214 let entry_idx = self.entries.partition_point(|e| e.hash < endpoint_hash);
221 let entry_idx = entry_idx % self.entries.len();
222 Some(self.entries[entry_idx].idx)
223 }
224}
225
226pub(crate) fn hash_request(
236 hash_policies: &Vec<RequestHashPolicy>,
237 url: &crate::Url,
238 headers: &http::HeaderMap,
239) -> Option<u64> {
240 let mut hash: Option<u64> = None;
241
242 for hash_policy in hash_policies {
243 if let Some(new_hash) = hash_component(hash_policy, url, headers) {
244 hash = Some(match hash {
245 Some(hash) => hash.rotate_left(1) ^ new_hash,
246 None => new_hash,
247 });
248
249 if hash_policy.terminal {
250 break;
251 }
252 }
253 }
254
255 hash
256}
257
258fn hash_component(
259 policy: &RequestHashPolicy,
260 url: &crate::Url,
261 headers: &http::HeaderMap,
262) -> Option<u64> {
263 match &policy.hasher {
264 RequestHasher::Header { name } => {
265 let mut header_values: Vec<_> = headers
266 .get_all(name)
267 .iter()
268 .map(http::HeaderValue::as_bytes)
269 .collect();
270
271 if header_values.is_empty() {
272 None
273 } else {
274 header_values.sort();
276 Some(thread_local_xxhash::hash_iter(header_values))
277 }
278 }
279 RequestHasher::QueryParam { ref name } => url.query().map(|query| {
280 let matching_vals = form_urlencoded::parse(query.as_bytes())
281 .filter_map(|(param, value)| (¶m == name).then_some(value));
282 thread_local_xxhash::hash_iter(matching_vals)
283 }),
284 }
285}
286
#[cfg(test)]
mod test_ring_hash {
    use crate::endpoints::Locality;

    use super::*;

    /// With a minimum size of 0 each endpoint gets exactly one point, and
    /// adding an endpoint leaves the existing endpoints' points unchanged.
    #[test]
    fn test_rebuild_ring() {
        let mut ring = Ring {
            eg_hash: 0,
            entries: Vec::new(),
        };

        ring.rebuild(0, &endpoint_group(123, ["1.1.1.1:80", "1.1.1.2:80"]));

        assert_eq!(ring.eg_hash, 123);
        assert_eq!(ring.entries.len(), 2);
        assert_eq!(ring_indexes(&ring), (0..2).collect::<Vec<_>>());
        assert_hashes_unique(&ring);

        let first_ring = ring.entries.clone();

        // a different group hash so the eg_hash update is actually observed
        ring.rebuild(
            0,
            &endpoint_group(456, ["1.1.1.1:80", "1.1.1.2:80", "1.1.1.3:80"]),
        );

        assert_eq!(ring.eg_hash, 456);
        assert_eq!(ring.entries.len(), 3);
        assert_eq!(ring_indexes(&ring), (0..3).collect::<Vec<_>>());
        assert_hashes_unique(&ring);

        // dropping the new endpoint's points must leave exactly the old ring
        let second_ring: Vec<_> = ring
            .entries
            .iter()
            .filter(|e| e.idx != 2)
            .cloned()
            .collect();
        assert_eq!(first_ring, second_ring);
    }

    /// A minimum size that doesn't divide evenly is rounded up per endpoint:
    /// 1024 / 3 -> 342 points each, 1026 total.
    #[test]
    fn test_rebuild_ring_min_size() {
        let mut ring = Ring {
            eg_hash: 0,
            entries: Vec::new(),
        };

        ring.rebuild(
            1024,
            &endpoint_group(123, ["1.1.1.1:80", "1.1.1.2:80", "1.1.1.3:80"]),
        );

        assert_eq!(ring.entries.len(), 1026);

        let mut counts = [0usize; 3];
        for entry in &ring.entries {
            counts[entry.idx] += 1;
        }
        assert!(counts.iter().all(|&c| c == 342));

        assert_hashes_unique(&ring);
    }

    /// An empty ring picks nothing, and rebuilding for an empty group neither
    /// panics nor produces points.
    #[test]
    fn test_empty_ring() {
        let mut ring = Ring {
            eg_hash: 0,
            entries: Vec::new(),
        };
        assert_eq!(ring.pick(42), None);

        ring.rebuild(1024, &endpoint_group(7, Vec::<&'static str>::new()));
        assert_eq!(ring.eg_hash, 7);
        assert!(ring.entries.is_empty());
        assert_eq!(ring.pick(42), None);
    }

    /// Hashes at or below the first point, and past the last point, map to the
    /// first point; hashes just at or below the last point map to the last.
    #[test]
    fn test_pick() {
        let mut ring = Ring {
            eg_hash: 0,
            entries: vec![],
        };
        ring.rebuild(
            0,
            &endpoint_group(123, ["1.1.1.1:80", "1.1.1.2:80", "1.1.1.3:80"]),
        );
        assert_eq!(ring.entries.len(), 3);

        // wrapping_add so a last point at u64::MAX still exercises wrap-around
        let hashes_to_first = [
            0,
            ring.entries[0].hash,
            ring.entries[2].hash.wrapping_add(1),
        ];
        for hash in hashes_to_first {
            assert_eq!(ring.pick(hash), Some(ring.entries[0].idx));
        }

        let hashes_to_last = [ring.entries[2].hash - 1, ring.entries[2].hash];
        for hash in hashes_to_last {
            assert_eq!(ring.pick(hash), Some(ring.entries[2].idx));
        }
    }

    /// Endpoint indexes of every ring point, sorted.
    fn ring_indexes(r: &Ring) -> Vec<usize> {
        let mut indexes: Vec<_> = r.entries.iter().map(|e| e.idx).collect();
        indexes.sort();
        indexes
    }

    /// Builds a single-locality endpoint group with an explicitly chosen hash.
    fn endpoint_group(hash: u64, addrs: impl IntoIterator<Item = &'static str>) -> EndpointGroup {
        let addrs = addrs.into_iter().map(|s| s.parse().unwrap()).collect();

        let mut eg = EndpointGroup::new([(Locality::Unknown, addrs)].into());
        eg.hash = hash;

        eg
    }

    /// Asserts that no two ring points share a hash.
    fn assert_hashes_unique(r: &Ring) {
        let mut hashes: Vec<_> = r.entries.iter().map(|e| e.hash).collect();
        hashes.sort();
        hashes.dedup();
        assert_eq!(hashes.len(), r.entries.len());
    }
}