1use crate::service::{Service, ServiceInstance, ServiceHealth};
3use crate::metrics;
4use anyhow::{Result, Context};
5use reqwest::Client;
6use std::net::SocketAddr;
7use std::time::Duration;
8use std::fmt;
9use tracing::{debug, info, warn, error};
10
/// The kind of probe used to decide whether a service instance is healthy.
///
/// The enum is fieldless, so it is `Copy` and `Hash` in addition to the
/// comparison traits: values can be passed around freely and used as map or
/// set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HealthCheckType {
    /// Issue an HTTP GET against the instance and inspect the response.
    Http,

    /// Attempt a TCP connection to the instance's host and port.
    Tcp,

    /// Script-based check, keyed off the `health_check_script` metadata entry.
    Script,

    /// Compare the age of the instance's last heartbeat against its TTL.
    Ttl,
}
26
27impl fmt::Display for HealthCheckType {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 match self {
30 HealthCheckType::Http => write!(f, "http"),
31 HealthCheckType::Tcp => write!(f, "tcp"),
32 HealthCheckType::Script => write!(f, "script"),
33 HealthCheckType::Ttl => write!(f, "ttl"),
34 }
35 }
36}
37
38#[derive(Debug, Clone)]
40pub struct HealthCheckConfig {
41 pub check_type: HealthCheckType,
43
44 pub interval_secs: u64,
46
47 pub timeout_ms: u64,
49
50 pub http_path: Option<String>,
52
53 pub http_expected_codes: Option<Vec<u16>>,
55
56 pub http_expected_body: Option<String>,
58
59 pub http_headers: Option<Vec<(String, String)>>,
61
62 pub deregister_on_failure: bool,
64
65 pub failure_threshold: u32,
67
68 pub success_threshold: u32,
70}
71
72impl Default for HealthCheckConfig {
73 fn default() -> Self {
74 Self {
75 check_type: HealthCheckType::Http,
76 interval_secs: 10,
77 timeout_ms: 2000,
78 http_path: Some("/health".to_string()),
79 http_expected_codes: Some(vec![200]),
80 http_expected_body: None,
81 http_headers: None,
82 deregister_on_failure: true,
83 failure_threshold: 3,
84 success_threshold: 1,
85 }
86 }
87}
88
89#[derive(Debug, Clone)]
91pub struct HealthCheckResult {
92 pub instance_id: String,
94
95 pub service_name: String,
97
98 pub check_type: HealthCheckType,
100
101 pub is_healthy: bool,
103
104 pub output: String,
106
107 pub duration_ms: u64,
109
110 pub timestamp: chrono::DateTime<chrono::Utc>,
112}
113
114pub struct HealthChecker {
116 http_client: Client,
118
119 default_config: HealthCheckConfig,
121}
122
123impl HealthChecker {
124 pub fn new(default_config: Option<HealthCheckConfig>) -> Self {
126 let http_client = Client::builder()
127 .timeout(Duration::from_millis(30000)) .build()
129 .expect("Failed to create HTTP client");
130
131 Self {
132 http_client,
133 default_config: default_config.unwrap_or_default(),
134 }
135 }
136
137 pub async fn check_instance(&self, instance: &ServiceInstance) -> Result<HealthCheckResult> {
139 let start = std::time::Instant::now();
140 let timestamp = chrono::Utc::now();
141
142 let check_type = match instance.protocol.as_str() {
144 "http" | "https" => HealthCheckType::Http,
145 "tcp" | "udp" => HealthCheckType::Tcp,
146 _ => HealthCheckType::Ttl,
147 };
148
149 let (is_healthy, output) = match check_type {
151 HealthCheckType::Http => {
152 self.perform_http_check(instance).await?
153 }
154 HealthCheckType::Tcp => {
155 self.perform_tcp_check(instance).await?
156 }
157 HealthCheckType::Script => {
158 self.perform_script_check(instance).await?
159 }
160 HealthCheckType::Ttl => {
161 self.perform_ttl_check(instance)?
162 }
163 };
164
165 let duration_ms = start.elapsed().as_millis() as u64;
166
167 metrics::record_service_health_check(&instance.service_name, is_healthy);
169
170 Ok(HealthCheckResult {
171 instance_id: instance.id.clone(),
172 service_name: instance.service_name.clone(),
173 check_type,
174 is_healthy,
175 output,
176 duration_ms,
177 timestamp,
178 })
179 }
180
181 async fn perform_http_check(&self, instance: &ServiceInstance) -> Result<(bool, String)> {
183 let health_path = instance.health_check_path.as_deref()
185 .unwrap_or(&self.default_config.http_path.as_deref().unwrap_or("/health"));
186
187 let url = format!("{}://{}:{}{}",
189 instance.protocol,
190 instance.host,
191 instance.port,
192 health_path
193 );
194
195 let timeout = Duration::from_millis(
197 self.default_config.timeout_ms
198 );
199
200 let mut req_builder = self.http_client.get(&url).timeout(timeout);
202
203 if let Some(headers) = &self.default_config.http_headers {
205 for (name, value) in headers {
206 req_builder = req_builder.header(name, value);
207 }
208 }
209
210 let response = match req_builder.send().await {
212 Ok(resp) => resp,
213 Err(e) => {
214 return Ok((false, format!("HTTP request failed: {}", e)));
215 }
216 };
217
218 let status = response.status();
220 let expected_codes = self.default_config.http_expected_codes.as_ref()
221 .map(|codes| codes.as_slice())
222 .unwrap_or(&[200]);
223
224 let status_ok = expected_codes.contains(&status.as_u16());
225
226 let body_ok = if let Some(expected_body) = &self.default_config.http_expected_body {
228 let body = response.text().await?;
229 body.contains(expected_body)
230 } else {
231 true
232 };
233
234 let is_healthy = status_ok && body_ok;
236
237 let output = if is_healthy {
238 format!("HTTP check passed with status {}", status)
239 } else if !status_ok {
240 format!("HTTP check failed: expected status code {:?}, got {}", expected_codes, status)
241 } else {
242 "HTTP check failed: expected body content not found".to_string()
243 };
244
245 Ok((is_healthy, output))
246 }
247
248 async fn perform_tcp_check(&self, instance: &ServiceInstance) -> Result<(bool, String)> {
250 let addr = format!("{}:{}", instance.host, instance.port);
252 let addr = addr.parse::<SocketAddr>()
253 .with_context(|| format!("Invalid socket address: {}", addr))?;
254
255 let timeout_duration = Duration::from_millis(
257 self.default_config.timeout_ms
258 );
259
260 match tokio::time::timeout(timeout_duration, tokio::net::TcpStream::connect(addr)).await {
262 Ok(Ok(_)) => {
263 Ok((true, format!("TCP connection successful to {}:{}", instance.host, instance.port)))
264 }
265 Ok(Err(e)) => {
266 Ok((false, format!("TCP connection failed: {}", e)))
267 }
268 Err(_) => {
269 Ok((false, format!("TCP connection timed out after {}ms", self.default_config.timeout_ms)))
270 }
271 }
272 }
273
274 async fn perform_script_check(&self, instance: &ServiceInstance) -> Result<(bool, String)> {
276 if let Some(script) = instance.metadata.get("health_check_script") {
281 Ok((true, format!("Script health check would run: {}", script)))
285 } else {
286 Ok((false, "No health check script defined".to_string()))
287 }
288 }
289
290 fn perform_ttl_check(&self, instance: &ServiceInstance) -> Result<(bool, String)> {
292 if let Some(ttl) = instance.ttl {
294 if let Some(last_heartbeat) = instance.last_heartbeat {
296 let now = chrono::Utc::now();
297 let duration = now.signed_duration_since(last_heartbeat);
298
299 let is_healthy = duration.num_seconds() < ttl as i64;
300
301 if is_healthy {
302 Ok((true, format!("TTL check passed: last heartbeat {} seconds ago", duration.num_seconds())))
303 } else {
304 Ok((false, format!("TTL check failed: last heartbeat {} seconds ago, TTL is {} seconds",
305 duration.num_seconds(), ttl)))
306 }
307 } else {
308 Ok((false, "TTL check failed: no heartbeat recorded".to_string()))
309 }
310 } else {
311 Ok((true, "No TTL defined, assuming healthy".to_string()))
313 }
314 }
315}
316
317pub struct HealthManager {
319 checker: HealthChecker,
321
322 check_counts: dashmap::DashMap<String, (u32, u32)>, config: HealthCheckConfig,
327}
328
329impl HealthManager {
330 pub fn new(config: Option<HealthCheckConfig>) -> Self {
332 let config = config.unwrap_or_default();
333
334 Self {
335 checker: HealthChecker::new(Some(config.clone())),
336 check_counts: dashmap::DashMap::new(),
337 config,
338 }
339 }
340
341 pub async fn check_and_update(&self, instance: &ServiceInstance) -> Result<(ServiceHealth, HealthCheckResult)> {
343 let result = self.checker.check_instance(instance).await?;
345
346 let mut counts = self.check_counts
348 .entry(instance.id.clone())
349 .or_insert((0, 0));
350
351 let (failures, successes) = if result.is_healthy {
353 (0, counts.1 + 1)
355 } else {
356 (counts.0 + 1, 0)
358 };
359
360 *counts = (failures, successes);
361
362 let new_health = if failures >= self.config.failure_threshold {
364 ServiceHealth::Unhealthy
366 } else if successes >= self.config.success_threshold {
367 ServiceHealth::Healthy
369 } else {
370 instance.health
372 };
373
374 Ok((new_health, result))
375 }
376
377 pub fn reset_counts(&self, instance_id: &str) {
379 self.check_counts.remove(instance_id);
380 }
381}