// junction_core/load_balancer.rs

use crate::{endpoints::EndpointGroup, error::Trace, hash::thread_local_xxhash};
use junction_api::backend::{Backend, LbPolicy, RequestHashPolicy, RequestHasher, RingHashParams};
use smol_str::ToSmolStr;
use std::{
    net::SocketAddr,
    sync::{
        atomic::{AtomicUsize, Ordering},
        RwLock,
    },
};

/// A [Backend][junction_api::backend::Backend] and the [LoadBalancer] it's
/// configured with.
#[derive(Debug)]
pub struct BackendLb {
    pub config: Backend,
    pub load_balancer: LoadBalancer,
}

#[derive(Debug)]
pub enum LoadBalancer {
    RoundRobin(RoundRobinLb),
    RingHash(RingHashLb),
}

impl LoadBalancer {
    pub(crate) fn load_balance<'e>(
        &self,
        trace: &mut Trace,
        endpoints: &'e EndpointGroup,
        url: &crate::Url,
        headers: &http::HeaderMap,
        previous_addrs: &[SocketAddr],
    ) -> Option<&'e SocketAddr> {
        match self {
            // RoundRobin skips previously picked addrs and ignores context
            LoadBalancer::RoundRobin(lb) => lb.pick_endpoint(trace, endpoints, previous_addrs),
            // RingHash needs context but doesn't care about history
            LoadBalancer::RingHash(lb) => {
                lb.pick_endpoint(trace, endpoints, url, headers, &lb.config.hash_params)
            }
        }
    }
}

impl LoadBalancer {
    pub(crate) fn from_config(config: &LbPolicy) -> Self {
        match config {
            LbPolicy::RoundRobin => LoadBalancer::RoundRobin(RoundRobinLb::default()),
            LbPolicy::RingHash(x) => LoadBalancer::RingHash(RingHashLb::new(x)),
            LbPolicy::Unspecified => LoadBalancer::RoundRobin(RoundRobinLb::default()),
        }
    }
}

// TODO: when doing weighted round robin, it's worth adapting the gRPC
// scheduling in static_stride_scheduler.cc instead of inventing a new technique
// ourselves.
//
// src/core/load_balancing/weighted_round_robin/static_stride_scheduler.cc
#[derive(Debug, Default)]
pub struct RoundRobinLb {
    idx: AtomicUsize,
}

impl RoundRobinLb {
    // FIXME: actually use locality
    fn pick_endpoint<'e>(
        &self,
        trace: &mut Trace,
        endpoint_group: &'e EndpointGroup,
        previous_addrs: &[SocketAddr],
    ) -> Option<&'e SocketAddr> {
        // TODO: actually use previous addrs to pick a new address. have to
        // decide if we return anything if all addresses have previously been
        // picked, or how many dupes we allow, etc. Envoy has a policy for this,
        // but we'd prefer to do something simpler by default.
        //
        // https://www.envoyproxy.io/docs/envoy/latest/api-v3/config/route/v3/route_components.proto#config-route-v3-retrypolicy
        let _ = previous_addrs;

        // bail out early on an empty group rather than taking a modulo by
        // zero below.
        let endpoint_count = endpoint_group.len();
        if endpoint_count == 0 {
            return None;
        }

        let idx = self.idx.fetch_add(1, Ordering::SeqCst) % endpoint_count;
        let addr = endpoint_group.nth(idx);

        trace.load_balance("ROUND_ROBIN", addr, Vec::new());

        addr
    }
}
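
// A minimal sketch of the wrapping-index arithmetic pick_endpoint relies on
// above, kept standalone because exercising RoundRobinLb directly needs a
// Trace and an EndpointGroup: fetch_add returns the previous counter value,
// so successive picks walk the endpoints 0, 1, 2, 0, 1, ... modulo the group
// size.
#[cfg(test)]
mod test_round_robin_index {
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn test_wrapping_index() {
        let counter = AtomicUsize::new(0);
        let len = 3;

        let picks: Vec<usize> = (0..7)
            .map(|_| counter.fetch_add(1, Ordering::SeqCst) % len)
            .collect();

        assert_eq!(picks, vec![0, 1, 2, 0, 1, 2, 0]);
    }
}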

/// A ring hash LB using Ketama hashing, roughly compatible with the gRPC and
/// Envoy implementations.
///
/// Like the Envoy and gRPC implementations, this load balancer ignores locality
/// and flattens all visible endpoints into a single hash ring. Unlike gRPC and
/// Envoy, this load balancer ignores endpoint weights.
#[derive(Debug)]
pub struct RingHashLb {
    config: RingHashParams,
    ring: RwLock<Ring>,
}

#[derive(Debug, Clone, PartialEq, Eq)]
struct RingEntry {
    hash: u64,
    idx: usize,
}

impl RingHashLb {
    fn new(config: &RingHashParams) -> Self {
        Self {
            config: config.clone(),
            ring: RwLock::new(Ring {
                eg_hash: 0,
                entries: Vec::with_capacity(config.min_ring_size as usize),
            }),
        }
    }

    fn pick_endpoint<'e>(
        &self,
        trace: &mut Trace,
        endpoints: &'e EndpointGroup,
        url: &crate::Url,
        headers: &http::HeaderMap,
        hash_params: &Vec<RequestHashPolicy>,
    ) -> Option<&'e SocketAddr> {
        let request_hash =
            hash_request(hash_params, url, headers).unwrap_or_else(crate::rand::random);

        let endpoint_idx = self.with_ring(endpoints, |r| r.pick(request_hash))?;
        let addr = endpoints.nth(endpoint_idx);

        trace.load_balance("RING_HASH", addr, vec![("hash", request_hash.to_smolstr())]);

        addr
    }

    // if you're reading this you might wonder why this takes a callback:
    //
    // the answer is that std's RwLock isn't upgradeable or downgradeable, so
    // it's not easy to have an RAII guard that starts with a read lock and
    // transparently upgrades to write when you need mut access.
    //
    // instead of figuring that out, this fn does the read-to-write upgrade and
    // you pass the callback. easy peasy.
    fn with_ring<F, T>(&self, endpoint_group: &EndpointGroup, mut cb: F) -> T
    where
        F: FnMut(&Ring) -> T,
    {
        // try to just use the existing ring
        //
        // explicitly drop the guard at the end so we can't get confused and
        // try to upgrade and deadlock ourselves.
        let ring = self.ring.read().unwrap();
        if ring.eg_hash == endpoint_group.hash {
            return cb(&ring);
        }
        std::mem::drop(ring);

        // write path:
        let mut ring = self.ring.write().unwrap();
        ring.rebuild(self.config.min_ring_size as usize, endpoint_group);
        cb(&ring)
    }
}
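
// A minimal standalone sketch (using a plain RwLock<u64>, not RingHashLb
// itself) of the read-then-reacquire pattern with_ring uses above: check
// under the read lock, drop that guard, and only then take the write lock,
// so a single thread never holds the read guard while waiting for the write
// lock.
#[cfg(test)]
mod test_read_then_write {
    use std::sync::RwLock;

    #[test]
    fn test_reacquire_as_writer() {
        let cell = RwLock::new(0u64);

        // read path: check whether the cached value is stale.
        let stale = {
            let value = cell.read().unwrap();
            *value != 42
        };

        // write path: the read guard is already dropped, so taking the write
        // lock here can't deadlock against our own read lock.
        if stale {
            *cell.write().unwrap() = 42;
        }

        assert_eq!(*cell.read().unwrap(), 42);
    }
}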

#[derive(Debug)]
struct Ring {
    // The hash of the EndpointGroup used to build the Ring. This is slightly
    // more stable than ResourceVersion, but could be changed to that.
    eg_hash: u64,
    entries: Vec<RingEntry>,
}

impl Ring {
    fn rebuild(&mut self, min_size: usize, endpoint_group: &EndpointGroup) {
        // before factoring in weights, we're doing some quick math to get a
        // multiple of the endpoint count that's large enough to fill the ring.
        //
        // once we factor in weights, this is going to get way nastier.
        let endpoint_count = endpoint_group.len();
        let repeats = usize::max((min_size as f64 / endpoint_count as f64).ceil() as usize, 1);
        let ring_size = repeats * endpoint_count;
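        // e.g. with min_size = 1024 and 3 endpoints, repeats = ceil(1024 / 3) = 342
        // and ring_size = 1026 (see test_rebuild_ring_min_size below).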

        self.entries.clear();
        self.entries.reserve(ring_size);

        for (idx, endpoint) in endpoint_group.iter().enumerate() {
            for i in 0..repeats {
                // we're using both the endpoint address and the repeat number
                // as the hash key. envoy stringifies things to do this but we
                // don't need to match exactly.
                //
                // https://github.com/envoyproxy/envoy/blob/66cc2175fe5044117c9f00af8d09293012778000/source/extensions/load_balancing_policies/ring_hash/ring_hash_lb.cc#L195-L205
                //
                // we're depending here on derive(Hash) for SocketAddr being
                // stable across rustc versions
                let hash = thread_local_xxhash::hash(&(endpoint, i));
                self.entries.push(RingEntry { hash, idx });
            }
        }

        self.eg_hash = endpoint_group.hash;
        self.entries.sort_by_key(|e| e.hash);
    }

    fn pick(&self, endpoint_hash: u64) -> Option<usize> {
        if self.entries.is_empty() {
            return None;
        }

        // the envoy/grpc implementations use a binary search cargo culted from
        // ketama_get_server in the original ketama implementation.
        //
        // instead of doing that, use the stdlib. partition_point does a binary
        // search and returns the index of the first entry whose hash is greater
        // than or equal to endpoint_hash; the modulo wraps back around to the
        // start of the ring when endpoint_hash is past the last entry.
        let entry_idx = self.entries.partition_point(|e| e.hash < endpoint_hash);
        let entry_idx = entry_idx % self.entries.len();
        Some(self.entries[entry_idx].idx)
    }
}

/// Hash an outgoing request based on a set of hash policies.
///
/// Like Envoy and gRPC, multiple hash policies are combined by rotating the
/// previous value left one bit and xor-ing the new component's hash into it.
///
/// See:
/// - https://github.com/grpc/proposal/blob/master/A42-xds-ring-hash-lb-policy.md#xds-api-fields
/// - https://github.com/envoyproxy/envoy/blob/main/source/common/http/hash_policy.cc#L236-L257
pub(crate) fn hash_request(
    hash_policies: &Vec<RequestHashPolicy>,
    url: &crate::Url,
    headers: &http::HeaderMap,
) -> Option<u64> {
    let mut hash: Option<u64> = None;

    for hash_policy in hash_policies {
        if let Some(new_hash) = hash_component(hash_policy, url, headers) {
            hash = Some(match hash {
                Some(hash) => hash.rotate_left(1) ^ new_hash,
                None => new_hash,
            });

            if hash_policy.terminal {
                break;
            }
        }
    }

    hash
}

fn hash_component(
    policy: &RequestHashPolicy,
    url: &crate::Url,
    headers: &http::HeaderMap,
) -> Option<u64> {
    match &policy.hasher {
        RequestHasher::Header { name } => {
            let mut header_values: Vec<_> = headers
                .get_all(name)
                .iter()
                .map(http::HeaderValue::as_bytes)
                .collect();

            if header_values.is_empty() {
                None
            } else {
                // sort values so that "foo,bar" and "bar,foo" hash to the same value
                header_values.sort();
                Some(thread_local_xxhash::hash_iter(header_values))
            }
        }
        RequestHasher::QueryParam { ref name } => url.query().map(|query| {
            let matching_vals = form_urlencoded::parse(query.as_bytes())
                .filter_map(|(param, value)| (&param == name).then_some(value));
            thread_local_xxhash::hash_iter(matching_vals)
        }),
    }
}
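
// A minimal sketch of the combining rule documented on hash_request: each new
// component hash is folded in as `prev.rotate_left(1) ^ new`. The values here
// are arbitrary u64s standing in for component hashes, not real request
// hashes.
#[cfg(test)]
mod test_hash_combine {
    #[test]
    fn test_rotate_xor_is_order_sensitive() {
        let (h1, h2) = (0xdead_beef_u64, 0x1234_5678_u64);

        // folding h1 then h2 differs from folding h2 then h1, which is why
        // policy order (and the `terminal` flag) affects the final hash.
        assert_ne!(h1.rotate_left(1) ^ h2, h2.rotate_left(1) ^ h1);
    }
}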

#[cfg(test)]
mod test_ring_hash {
    use crate::endpoints::Locality;

    use super::*;

    #[test]
    fn test_rebuild_ring() {
        let mut ring = Ring {
            eg_hash: 0,
            entries: Vec::new(),
        };

        // rebuild a ring with no min size
        ring.rebuild(0, &endpoint_group(123, ["1.1.1.1:80", "1.1.1.2:80"]));

        assert_eq!(ring.eg_hash, 123);
        assert_eq!(ring.entries.len(), 2);
        assert_eq!(ring_indexes(&ring), (0..2).collect::<Vec<_>>());
        assert_hashes_unique(&ring);

        let first_ring = ring.entries.clone();

        // rebuild the ring with new ips
        ring.rebuild(
            0,
            &endpoint_group(123, ["1.1.1.1:80", "1.1.1.2:80", "1.1.1.3:80"]),
        );

        assert_eq!(ring.eg_hash, 123);
        assert_eq!(ring.entries.len(), 3);
        assert_eq!(ring_indexes(&ring), (0..3).collect::<Vec<_>>());
        assert_hashes_unique(&ring);

        // the ring with the entry for the new address (1.1.1.3) removed should
        // be the same as the ring for the prior group.
        let second_ring: Vec<_> = ring
            .entries
            .iter()
            .filter(|e| e.idx != 2)
            .cloned()
            .collect();
        assert_eq!(first_ring, second_ring);
    }

    #[test]
    fn test_rebuild_ring_min_size() {
        let mut ring = Ring {
            eg_hash: 0,
            entries: Vec::new(),
        };

        // rebuild a ring with a min size of 1024
        ring.rebuild(
            1024,
            &endpoint_group(123, ["1.1.1.1:80", "1.1.1.2:80", "1.1.1.3:80"]),
        );

        // 1026 is the smallest multiple of 3 that is at least 1024
        assert_eq!(ring.entries.len(), 1026);

        // every idx should be repeated 1026 / 3 = 342 times.
        let mut counts = [0usize; 3];
        for entry in &ring.entries {
            counts[entry.idx] += 1;
        }
        assert!(counts.iter().all(|&c| c == 342));

        // every hash should be unique
        assert_hashes_unique(&ring);
    }
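
    // An extra sketch on top of the existing tests: rebuilding from the same
    // endpoint group should be deterministic, assuming thread_local_xxhash is
    // stable within a thread, so two rings built from identical input end up
    // with identical entries.
    #[test]
    fn test_rebuild_is_deterministic() {
        let mut a = Ring {
            eg_hash: 0,
            entries: Vec::new(),
        };
        let mut b = Ring {
            eg_hash: 0,
            entries: Vec::new(),
        };

        let group = endpoint_group(7, ["1.1.1.1:80", "1.1.1.2:80", "1.1.1.3:80"]);
        a.rebuild(16, &group);
        b.rebuild(16, &group);

        assert_eq!(a.entries, b.entries);
    }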

    #[test]
    fn test_pick() {
        let mut ring = Ring {
            eg_hash: 0,
            entries: vec![],
        };
        ring.rebuild(
            0,
            &EndpointGroup::new(
                [(
                    Locality::Unknown,
                    vec![
                        "1.1.1.1:80".parse().unwrap(),
                        "1.1.1.2:80".parse().unwrap(),
                        "1.1.1.3:80".parse().unwrap(),
                    ],
                )]
                .into(),
            ),
        );

        // anything less than or equal to the first hash, or greater than the
        // last hash, should hash to the start of the ring.
        let hashes_to_first = [0, ring.entries[0].hash, ring.entries[2].hash + 1];
        for hash in hashes_to_first {
            assert_eq!(ring.pick(hash), Some(ring.entries[0].idx));
        }

        let hashes_to_last = [ring.entries[2].hash - 1, ring.entries[2].hash];
        for hash in hashes_to_last {
            assert_eq!(ring.pick(hash), Some(ring.entries[2].idx));
        }
    }
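
    // An extra sketch covering the early return in Ring::pick: an empty ring
    // yields None instead of panicking.
    #[test]
    fn test_pick_empty_ring() {
        let ring = Ring {
            eg_hash: 0,
            entries: Vec::new(),
        };

        assert_eq!(ring.pick(123), None);
    }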

    fn ring_indexes(r: &Ring) -> Vec<usize> {
        let mut indexes: Vec<_> = r.entries.iter().map(|e| e.idx).collect();
        indexes.sort();
        indexes
    }

    fn endpoint_group(hash: u64, addrs: impl IntoIterator<Item = &'static str>) -> EndpointGroup {
        let addrs = addrs.into_iter().map(|s| s.parse().unwrap()).collect();

        let mut eg = EndpointGroup::new([(Locality::Unknown, addrs)].into());
        eg.hash = hash;

        eg
    }

    fn assert_hashes_unique(r: &Ring) {
        let mut hashes: Vec<_> = r.entries.iter().map(|e| e.hash).collect();
        hashes.sort();
        hashes.dedup();
        assert_eq!(hashes.len(), r.entries.len());
    }
}