junction_core/
endpoints.rs1use junction_api::{
2 backend::BackendId,
3 http::{RouteRetry, RouteTimeouts},
4};
5use std::{collections::BTreeMap, net::SocketAddr, sync::Arc};
6
7use crate::{error::Trace, hash::thread_local_xxhash, HttpResult};
8
9#[derive(Debug, Clone)]
14pub struct Endpoint {
15 pub(crate) method: http::Method,
17 pub(crate) url: crate::Url,
18 pub(crate) headers: http::HeaderMap,
19
20 pub(crate) backend: BackendId,
24 pub(crate) address: SocketAddr,
25 pub(crate) timeouts: Option<RouteTimeouts>,
26 pub(crate) retry: Option<RouteRetry>,
27
28 pub(crate) trace: Trace,
30 pub(crate) previous_addrs: Vec<SocketAddr>,
31}
32
33impl Endpoint {
34 pub fn method(&self) -> &http::Method {
35 &self.method
36 }
37
38 pub fn url(&self) -> &crate::Url {
39 &self.url
40 }
41
42 pub fn headers(&self) -> &http::HeaderMap {
43 &self.headers
44 }
45
46 pub fn addr(&self) -> SocketAddr {
47 self.address
48 }
49
50 pub fn timeouts(&self) -> &Option<RouteTimeouts> {
51 &self.timeouts
52 }
53
54 pub fn retry(&self) -> &Option<RouteRetry> {
55 &self.retry
56 }
57
58 pub(crate) fn should_retry(&self, result: HttpResult) -> bool {
59 let Some(retry) = &self.retry else {
60 return false;
61 };
62 let Some(allowed) = &retry.attempts else {
63 return false;
64 };
65 let allowed = *allowed as usize;
66
67 match result {
68 HttpResult::StatusError(code) if !retry.codes.contains(&code.as_u16()) => return false,
69 _ => (),
70 }
71
72 let attempts = self.previous_addrs.len() + 1;
75
76 attempts < allowed
77 }
78
79 pub fn print_trace(&self) {
81 let start = self.trace.start();
82 let mut phase = None;
83
84 for event in self.trace.events() {
85 if phase != Some(event.phase) {
86 eprintln!("{:?}", event.phase);
87 phase = Some(event.phase);
88 }
89
90 let elapsed = event.at.duration_since(start).as_secs_f64();
91 eprint!(" {elapsed:.06}: {name:>16?}", name = event.kind);
92 if !event.kv.is_empty() {
93 eprint!(":");
94
95 for (k, v) in &event.kv {
96 eprint!(" {k}={v}")
97 }
98 }
99 eprintln!();
100 }
101 }
102}
103
/// Where a set of endpoints is located.
///
/// Variant order is significant: the derived `Ord` places `Unknown` before
/// every `Known` locality, which fixes the order of any map keyed by it.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) enum Locality {
    /// No locality information is attached to the addresses.
    Unknown,
    /// A concrete region/zone pair.
    // Nothing in this module constructs this variant outside of tests, so the
    // lint is silenced. NOTE(review): presumably reserved for a locality-aware
    // discovery source — confirm before removing.
    #[allow(unused)]
    Known(LocalityInfo),
}

/// A region/zone pair describing a locality.
///
/// Field order is significant: the derived `Ord` compares `region` first and
/// only falls back to `zone` on a tie.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct LocalityInfo {
    pub(crate) region: String,
    pub(crate) zone: String,
}
116
117pub struct EndpointIter {
119 endpoint_group: Arc<EndpointGroup>,
120}
121
122impl From<Arc<EndpointGroup>> for EndpointIter {
123 fn from(endpoint_group: Arc<EndpointGroup>) -> Self {
124 Self { endpoint_group }
125 }
126}
127
128impl EndpointIter {
131 pub fn addrs(&self) -> impl Iterator<Item = &SocketAddr> {
134 self.endpoint_group.iter()
135 }
136}
137
138#[derive(Debug, Default, Hash, PartialEq, Eq)]
139pub(crate) struct EndpointGroup {
140 pub(crate) hash: u64,
141 endpoints: BTreeMap<Locality, Vec<SocketAddr>>,
142}
143
144impl EndpointGroup {
145 pub(crate) fn new(endpoints: BTreeMap<Locality, Vec<SocketAddr>>) -> Self {
146 let hash = thread_local_xxhash::hash(&endpoints);
147 Self { hash, endpoints }
148 }
149
150 pub(crate) fn from_dns_addrs(addrs: impl IntoIterator<Item = SocketAddr>) -> Self {
151 let mut endpoints = BTreeMap::new();
152 let endpoint_addrs = addrs.into_iter().collect();
153 endpoints.insert(Locality::Unknown, endpoint_addrs);
154
155 Self::new(endpoints)
156 }
157
158 pub(crate) fn len(&self) -> usize {
159 self.endpoints.values().map(|v| v.len()).sum()
160 }
161
162 pub(crate) fn iter(&self) -> impl Iterator<Item = &SocketAddr> {
168 self.endpoints.values().flatten()
169 }
170
171 pub(crate) fn nth(&self, n: usize) -> Option<&SocketAddr> {
173 let mut n = n;
174 for endpoints in self.endpoints.values() {
175 if n < endpoints.len() {
176 return Some(&endpoints[n]);
177 }
178 n -= endpoints.len();
179 }
180
181 None
182 }
183}
184
#[cfg(test)]
mod test {
    use std::net::Ipv4Addr;

    use http::StatusCode;
    use junction_api::{Duration, Service};

    use crate::Url;

    use super::*;

    #[test]
    fn test_endpoint_should_retry_no_policy() {
        let mut endpoint = new_endpoint();
        endpoint.retry = None;

        assert!(!endpoint.should_retry(HttpResult::StatusFailed));
        assert!(!endpoint.should_retry(HttpResult::StatusError(
            http::StatusCode::SERVICE_UNAVAILABLE
        )));
    }

    #[test]
    fn test_endpoint_should_retry_without_attempt_limit() {
        let mut endpoint = new_endpoint();
        endpoint.retry = Some(RouteRetry {
            codes: vec![StatusCode::BAD_REQUEST.as_u16()],
            attempts: None,
            backoff: None,
        });

        // A policy with no attempt limit never retries.
        assert!(!endpoint.should_retry(HttpResult::StatusFailed));
        assert!(!endpoint.should_retry(HttpResult::StatusError(StatusCode::BAD_REQUEST)));
    }

    #[test]
    fn test_endpoint_should_retry_with_policy() {
        let mut endpoint = new_endpoint();
        endpoint.retry = Some(RouteRetry {
            codes: vec![StatusCode::BAD_REQUEST.as_u16()],
            attempts: Some(3),
            backoff: Some(Duration::from_secs(2)),
        });

        assert!(endpoint.should_retry(HttpResult::StatusFailed));
        assert!(endpoint.should_retry(HttpResult::StatusError(StatusCode::BAD_REQUEST)));
        assert!(!endpoint.should_retry(HttpResult::StatusError(StatusCode::SERVICE_UNAVAILABLE)));
    }

    #[test]
    fn test_endpoint_should_retry_with_history() {
        let mut endpoint = new_endpoint();
        endpoint.retry = Some(RouteRetry {
            codes: vec![StatusCode::BAD_REQUEST.as_u16()],
            attempts: Some(3),
            backoff: Some(Duration::from_secs(2)),
        });

        // First attempt: 1 of 3 used.
        assert!(endpoint.should_retry(HttpResult::StatusFailed));
        assert!(endpoint.should_retry(HttpResult::StatusError(StatusCode::BAD_REQUEST)));
        assert!(!endpoint.should_retry(HttpResult::StatusError(StatusCode::SERVICE_UNAVAILABLE)));

        // Second attempt: 2 of 3 used.
        endpoint
            .previous_addrs
            .push(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 443));
        assert!(endpoint.should_retry(HttpResult::StatusFailed));
        assert!(endpoint.should_retry(HttpResult::StatusError(StatusCode::BAD_REQUEST)));

        // Third attempt: budget exhausted.
        endpoint
            .previous_addrs
            .push(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 443));
        assert!(!endpoint.should_retry(HttpResult::StatusFailed));
        assert!(!endpoint.should_retry(HttpResult::StatusError(StatusCode::BAD_REQUEST)));
    }

    #[test]
    fn test_endpoint_group_from_dns_addrs() {
        let addrs = [addr(1), addr(2), addr(3)];
        let group = EndpointGroup::from_dns_addrs(addrs);

        assert_eq!(group.len(), 3);
        assert_eq!(group.iter().copied().collect::<Vec<_>>(), addrs);
        assert_eq!(group.nth(0), Some(&addrs[0]));
        assert_eq!(group.nth(2), Some(&addrs[2]));
        assert_eq!(group.nth(3), None);
    }

    #[test]
    fn test_endpoint_group_empty() {
        let group = EndpointGroup::from_dns_addrs(Vec::<SocketAddr>::new());
        assert_eq!(group.len(), 0);
        assert_eq!(group.iter().count(), 0);
        assert_eq!(group.nth(0), None);

        let group = EndpointGroup::default();
        assert_eq!(group.len(), 0);
        assert_eq!(group.iter().count(), 0);
        assert_eq!(group.nth(0), None);
    }

    #[test]
    fn test_endpoint_group_nth_across_localities() {
        let group = EndpointGroup::new(BTreeMap::from([
            (known_locality(), vec![addr(2), addr(3)]),
            (Locality::Unknown, vec![addr(1)]),
        ]));

        // `Unknown` sorts before any `Known` locality, so its addresses come first.
        assert_eq!(group.len(), 3);
        assert_eq!(
            group.iter().copied().collect::<Vec<_>>(),
            [addr(1), addr(2), addr(3)]
        );
        for (i, expected) in group.iter().enumerate() {
            assert_eq!(group.nth(i), Some(expected));
        }
        assert_eq!(group.nth(3), None);
    }

    #[test]
    fn test_endpoint_group_nth_skips_empty_localities() {
        let group = EndpointGroup::new(BTreeMap::from([
            (Locality::Unknown, vec![]),
            (known_locality(), vec![addr(7)]),
        ]));

        assert_eq!(group.len(), 1);
        assert_eq!(group.nth(0), Some(&addr(7)));
        assert_eq!(group.nth(1), None);
    }

    #[test]
    fn test_endpoint_iter_addrs() {
        let group = Arc::new(EndpointGroup::from_dns_addrs([addr(1), addr(2)]));
        let iter = EndpointIter::from(group);

        assert_eq!(
            iter.addrs().copied().collect::<Vec<_>>(),
            [addr(1), addr(2)]
        );
    }

    fn addr(last_octet: u8) -> SocketAddr {
        SocketAddr::new(Ipv4Addr::new(10, 0, 0, last_octet).into(), 8080)
    }

    fn known_locality() -> Locality {
        Locality::Known(LocalityInfo {
            region: "region-1".to_string(),
            zone: "zone-a".to_string(),
        })
    }

    fn new_endpoint() -> Endpoint {
        let url: Url = "http://example.com".parse().unwrap();
        let backend = Service::dns(url.hostname()).unwrap().as_backend_id(443);
        let address = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 443);

        Endpoint {
            method: http::Method::GET,
            url,
            headers: Default::default(),
            backend,
            address,
            timeouts: None,
            retry: None,
            trace: Trace::new(),
            previous_addrs: vec![],
        }
    }
}