junction_api/
backend.rs

1//! Backends are the logical target of network traffic. They have an identity and
2//! a load-balancing policy. See [Backend] to get started.
3
4use crate::{Error, Service};
5use serde::{Deserialize, Serialize};
6
7#[cfg(feature = "typeinfo")]
8use junction_typeinfo::TypeInfo;
9
10/// A Backend is uniquely identifiable by a combination of Service and port.
11///
12/// [Backend][crate::backend::Backend].
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
14#[cfg_attr(feature = "typeinfo", derive(TypeInfo))]
15pub struct BackendId {
16    /// The logical traffic target that this backend configures.
17    #[serde(flatten)]
18    pub service: Service,
19
20    /// The port backend traffic is sent on.
21    pub port: u16,
22}
23
24impl std::fmt::Display for BackendId {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        self.write_name(f)
27    }
28}
29
30impl std::str::FromStr for BackendId {
31    type Err = Error;
32
33    fn from_str(name: &str) -> Result<Self, Self::Err> {
34        let (name, port) = super::parse_port(name)?;
35        let port =
36            port.ok_or_else(|| Error::new_static("expected a fully qualified name with a port"))?;
37        let service = Service::from_str(name)?;
38
39        Ok(Self { service, port })
40    }
41}
42
43impl BackendId {
44    /// The cannonical name of this ID. This is an alias for the
45    /// [Display][std::fmt::Display] representation of this ID.
46    pub fn name(&self) -> String {
47        let mut buf = String::new();
48        self.write_name(&mut buf).unwrap();
49        buf
50    }
51
52    fn write_name(&self, w: &mut impl std::fmt::Write) -> std::fmt::Result {
53        self.service.write_name(w)?;
54        write!(w, ":{port}", port = self.port)?;
55
56        Ok(())
57    }
58
59    #[doc(hidden)]
60    pub fn lb_config_route_name(&self) -> String {
61        let mut buf = String::new();
62        self.write_lb_config_route_name(&mut buf).unwrap();
63        buf
64    }
65
66    fn write_lb_config_route_name(&self, w: &mut impl std::fmt::Write) -> std::fmt::Result {
67        self.service.write_lb_config_route_name(w)?;
68        write!(w, ":{port}", port = self.port)?;
69        Ok(())
70    }
71
72    #[doc(hidden)]
73    pub fn from_lb_config_route_name(name: &str) -> Result<Self, Error> {
74        let (name, port) = super::parse_port(name)?;
75        let port =
76            port.ok_or_else(|| Error::new_static("expected a fully qualified name with a port"))?;
77
78        let target = Service::from_lb_config_route_name(name)?;
79
80        Ok(Self {
81            service: target,
82            port,
83        })
84    }
85}
86
87/// A Backend is a logical target for network traffic.
88///
89/// A backend configures how all traffic for its `target` is handled. Any
90/// traffic routed to this backend will use the configured load balancing policy
91/// to spread traffic across available endpoints.
92#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
93#[cfg_attr(feature = "typeinfo", derive(TypeInfo))]
94pub struct Backend {
95    /// A unique identifier for this backend.
96    pub id: BackendId,
97
98    /// How traffic to this target should be load balanced.
99    pub lb: LbPolicy,
100}
101
102// TODO: figure out how we want to support the filter_state/connection_properties style of hashing
103// based on source ip or grpc channel.
104//
105// TODO: add support for query parameter based hashing, which involves parsing query parameters,
106// which http::uri just doesn't do. switch the whole crate to url::Url or something.
107//
108// TODO: Random, Maglev
109//
110/// A policy describing how traffic to this target should be load balanced.
111#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
112#[serde(tag = "type")]
113#[cfg_attr(feature = "typeinfo", derive(TypeInfo))]
114pub enum LbPolicy {
115    /// A simple round robin load balancing policy. Endpoints are picked in sequential order, but
116    /// that order may vary client to client.
117    RoundRobin,
118
119    /// Use a ketama-style consistent hashing algorithm to route this request.
120    RingHash(RingHashParams),
121
122    /// No load balancing algorithm was specified. Clients may decide how load balancing happens
123    /// for this target.
124    #[default]
125    Unspecified,
126}
127
128impl LbPolicy {
129    /// Return `true` if this policy is [LbPolicy::Unspecified].
130    pub fn is_unspecified(&self) -> bool {
131        matches!(self, Self::Unspecified)
132    }
133}
134
135/// Policy for configuring a ketama-style consistent hashing algorithm.
136#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
137#[serde(deny_unknown_fields)]
138#[cfg_attr(feature = "typeinfo", derive(TypeInfo))]
139pub struct RingHashParams {
140    /// The minimum size of the hash ring
141    #[serde(default = "default_min_ring_size", alias = "minRingSize")]
142    pub min_ring_size: u32,
143
144    /// How to hash an outgoing request into the ring.
145    ///
146    /// Hash parameters are applied in order. If the request is missing an input, it has no effect
147    /// on the final hash. Hashing stops when only when all polices have been applied or a
148    /// `terminal` policy matches part of an incoming request.
149    ///
150    /// This allows configuring a fallback-style hash, where the value of `HeaderA` gets used,
151    /// falling back to the value of `HeaderB`.
152    ///
153    /// If no policies match, a random hash is generated for each request.
154    #[serde(default, skip_serializing_if = "Vec::is_empty", alias = "hashParams")]
155    pub hash_params: Vec<RequestHashPolicy>,
156}
157
/// The ring size used when `min_ring_size` is absent from a deserialized
/// [RingHashParams].
pub(crate) const fn default_min_ring_size() -> u32 {
    const DEFAULT_MIN_RING_SIZE: u32 = 1024;
    DEFAULT_MIN_RING_SIZE
}
161
162#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
163#[cfg_attr(feature = "typeinfo", derive(TypeInfo))]
164pub struct RequestHashPolicy {
165    /// Whether to stop immediately after hashing this value.
166    ///
167    /// This is useful if you want to try to hash a value, and then fall back to
168    /// another as a default if it wasn't set.
169    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
170    pub terminal: bool,
171
172    #[serde(flatten)]
173    pub hasher: RequestHasher,
174}
175
176#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
177#[serde(tag = "type")]
178#[cfg_attr(feature = "typeinfo", derive(TypeInfo))]
179pub enum RequestHasher {
180    /// Hash the value of a header. If the header has multiple values, they will
181    /// all be used as hash input.
182    #[serde(alias = "header")]
183    Header {
184        /// The name of the header to use as hash input.
185        name: String,
186    },
187
188    /// Hash the value of an HTTP query parameter.
189    #[serde(alias = "query")]
190    QueryParam {
191        /// The name of the query parameter to hash
192        name: String,
193    },
194}
195
#[cfg(test)]
mod test {
    use std::fmt::Debug;

    use serde_json::json;

    use super::*;

    /// Each `LbPolicy` variant must survive a JSON round trip unchanged.
    #[test]
    fn test_lb_policy_json() {
        assert_round_trip::<LbPolicy>(json!({ "type": "Unspecified" }));
        assert_round_trip::<LbPolicy>(json!({ "type": "RoundRobin" }));
        assert_round_trip::<LbPolicy>(json!({
            "type": "RingHash",
            "min_ring_size": 100,
            "hash_params": [
                { "type": "Header", "name": "x-user", "terminal": true },
                { "type": "QueryParam", "name": "u" },
            ],
        }));
    }

    /// A `Backend` with a flattened kube service ID must survive a JSON round
    /// trip unchanged.
    #[test]
    fn test_backend_json() {
        assert_round_trip::<Backend>(json!({
            "id": { "type": "kube", "name": "foo", "namespace": "bar", "port": 789 },
            "lb": { "type": "Unspecified" },
        }));
    }

    /// Deserializes `value` as a `T`, serializes the result back to JSON, and
    /// asserts that the output is equal to the input.
    #[track_caller]
    fn assert_round_trip<T>(value: serde_json::Value)
    where
        T: Debug + Serialize + for<'de> Deserialize<'de>,
    {
        let parsed: T = serde_json::from_value(value.clone()).expect("failed to deserialize");
        let reserialized = serde_json::to_value(&parsed).expect("failed to serialize");

        assert_eq!(value, reserialized, "serialized value should round-trip");
    }
}