diff --git a/Cargo.lock b/Cargo.lock index 3c47630..aed2a56 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -778,6 +778,7 @@ dependencies = [ "hyper", "hyper-util", "log", + "rand", "serde", "serde_json", "tokio", diff --git a/Cargo.toml b/Cargo.toml index 562af72..1a8fe78 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ http-body-util = "0.1" hyper = {version = "1.5.0", features = ["full"]} hyper-util = {version = "0.1", features = ["full"]} log = "0.4.22" +rand = "0.8.5" serde = {version = "1.0.215", features = ["derive"]} serde_json = "1.0.133" tokio = {version = "1.41.1", features = ["full"]} diff --git a/src/health_monitor.rs b/src/health_monitor.rs new file mode 100644 index 0000000..1d87729 --- /dev/null +++ b/src/health_monitor.rs @@ -0,0 +1,233 @@ +use std::sync::atomic::{AtomicU64, AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::{Mutex, OnceCell}; +use tokio::time::sleep; +use log::{info, warn, error}; + +/// Connection health status +#[derive(Debug, Clone)] +pub struct ConnectionHealth { + pub is_healthy: bool, + pub last_success: Option, + pub last_failure: Option, + pub consecutive_failures: u32, + pub total_attempts: u64, + pub total_successes: u64, + pub total_failures: u64, + pub uptime_percentage: f64, +} + +/// Service type +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ServiceType { + WebSocket, + WanPolling, +} + +impl std::fmt::Display for ServiceType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ServiceType::WebSocket => write!(f, "WebSocket"), + ServiceType::WanPolling => write!(f, "WAN Polling"), + } + } +} + +/// Health monitor +pub struct HealthMonitor { + websocket_stats: Arc, + wan_polling_stats: Arc, +} + +struct ServiceStats { + is_healthy: AtomicBool, + last_success: Mutex>, + last_failure: Mutex>, + consecutive_failures: AtomicU64, + total_attempts: AtomicU64, + total_successes: AtomicU64, + total_failures: AtomicU64, + start_time: Instant, +} + +impl ServiceStats { + fn new() -> Self { + Self { + is_healthy: AtomicBool::new(false), + last_success: Mutex::new(None), + last_failure: Mutex::new(None), + consecutive_failures: AtomicU64::new(0), + total_attempts: AtomicU64::new(0), + total_successes: AtomicU64::new(0), + total_failures: AtomicU64::new(0), + start_time: Instant::now(), + } + } + + async fn record_attempt(&self) { + self.total_attempts.fetch_add(1, Ordering::Relaxed); + } + + async fn record_success(&self) { + self.is_healthy.store(true, Ordering::Relaxed); + self.consecutive_failures.store(0, Ordering::Relaxed); + self.total_successes.fetch_add(1, Ordering::Relaxed); + *self.last_success.lock().await = Some(Instant::now()); + } + + async fn record_failure(&self) { + self.is_healthy.store(false, Ordering::Relaxed); + self.consecutive_failures.fetch_add(1, Ordering::Relaxed); + self.total_failures.fetch_add(1, Ordering::Relaxed); + *self.last_failure.lock().await = Some(Instant::now()); + } + + async fn get_health(&self) -> ConnectionHealth { + let total_attempts = self.total_attempts.load(Ordering::Relaxed); + let total_successes = self.total_successes.load(Ordering::Relaxed); + let total_failures = self.total_failures.load(Ordering::Relaxed); + + let uptime_percentage = if total_attempts > 0 { + (total_successes as f64 / total_attempts as f64) * 100.0 + } else { + 0.0 + }; + + ConnectionHealth { + is_healthy: self.is_healthy.load(Ordering::Relaxed), + last_success: *self.last_success.lock().await, + last_failure: *self.last_failure.lock().await, + consecutive_failures: self.consecutive_failures.load(Ordering::Relaxed) as u32, + total_attempts, + total_successes, + total_failures, + uptime_percentage, + } + } +} + +impl HealthMonitor { + pub async fn global() -> &'static Self { + static HEALTH_MONITOR: OnceCell = OnceCell::const_new(); + + HEALTH_MONITOR + .get_or_init(|| async { + let monitor = HealthMonitor::new(); + + // 启动健康状态报告任务 + let monitor_clone = monitor.clone(); + tokio::spawn(async move { + monitor_clone.start_health_reporting().await; + }); + + monitor + }) + .await + } + + fn new() -> Self { + Self { + websocket_stats: Arc::new(ServiceStats::new()), + wan_polling_stats: Arc::new(ServiceStats::new()), + } + } + + fn clone(&self) -> Self { + Self { + websocket_stats: self.websocket_stats.clone(), + wan_polling_stats: self.wan_polling_stats.clone(), + } + } + + /// Record service attempt + pub async fn record_attempt(&self, service: ServiceType) { + match service { + ServiceType::WebSocket => self.websocket_stats.record_attempt().await, + ServiceType::WanPolling => self.wan_polling_stats.record_attempt().await, + } + } + + /// Record service success + pub async fn record_success(&self, service: ServiceType) { + match service { + ServiceType::WebSocket => self.websocket_stats.record_success().await, + ServiceType::WanPolling => self.wan_polling_stats.record_success().await, + } + } + + /// Record service failure + pub async fn record_failure(&self, service: ServiceType) { + match service { + ServiceType::WebSocket => self.websocket_stats.record_failure().await, + ServiceType::WanPolling => self.wan_polling_stats.record_failure().await, + } + } + + /// Get service health status + pub async fn get_health(&self, service: ServiceType) -> ConnectionHealth { + match service { + ServiceType::WebSocket => self.websocket_stats.get_health().await, + ServiceType::WanPolling => self.wan_polling_stats.get_health().await, + } + } + + /// Get all services health status + pub async fn get_all_health(&self) -> (ConnectionHealth, ConnectionHealth) { + let websocket_health = self.websocket_stats.get_health().await; + let wan_polling_health = self.wan_polling_stats.get_health().await; + (websocket_health, wan_polling_health) + } + + /// Start health status reporting task + async fn start_health_reporting(&self) { + let mut interval = tokio::time::interval(Duration::from_secs(60)); // Report every minute + + loop { + interval.tick().await; + + let (websocket_health, wan_health) = self.get_all_health().await; + + info!("=== Health Status Report ==="); + self.log_service_health("WebSocket", &websocket_health); + self.log_service_health("WAN Polling", &wan_health); + + // 如果有服务不健康,发出警告 + if !websocket_health.is_healthy { + warn!("WebSocket service is unhealthy! Consecutive failures: {}", websocket_health.consecutive_failures); + } + if !wan_health.is_healthy { + warn!("WAN Polling service is unhealthy! Consecutive failures: {}", wan_health.consecutive_failures); + } + + // 如果连续失败次数过多,发出错误警报 + if websocket_health.consecutive_failures > 10 { + error!("WebSocket service has {} consecutive failures!", websocket_health.consecutive_failures); + } + if wan_health.consecutive_failures > 10 { + error!("WAN Polling service has {} consecutive failures!", wan_health.consecutive_failures); + } + } + } + + fn log_service_health(&self, service_name: &str, health: &ConnectionHealth) { + info!( + "{}: {} | Uptime: {:.1}% | Attempts: {} | Successes: {} | Failures: {} | Consecutive Failures: {}", + service_name, + if health.is_healthy { "HEALTHY" } else { "UNHEALTHY" }, + health.uptime_percentage, + health.total_attempts, + health.total_successes, + health.total_failures, + health.consecutive_failures + ); + + if let Some(last_success) = health.last_success { + info!("{}: Last success: {:.1}s ago", service_name, last_success.elapsed().as_secs_f64()); + } + + if let Some(last_failure) = health.last_failure { + info!("{}: Last failure: {:.1}s ago", service_name, last_failure.elapsed().as_secs_f64()); + } + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..683cf89 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,6 @@ +pub mod retry; +pub mod health_monitor; + +// 重新导出常用的类型和函数 +pub use retry::{RetryConfig, Retrier, retry_with_config, retry, retry_forever}; +pub use health_monitor::{HealthMonitor, ServiceType, ConnectionHealth}; diff --git a/src/main.rs b/src/main.rs index 917614a..346976b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,9 +2,9 @@ use std::time::Duration; use anyhow::{anyhow, bail}; use clap::Parser; -use futures_util::StreamExt; -use log::{debug, error}; -use tokio::time::sleep; +use futures_util::{SinkExt, StreamExt}; +use log::{debug, error, info, warn}; +use tokio::time::timeout; use tokio_tungstenite::{ connect_async, tungstenite::{ @@ -15,10 +15,14 @@ use tokio_tungstenite::{ }; mod clash_conn_msg; +mod health_monitor; +pub mod retry; mod statistics; mod udp_server; mod wan; +use health_monitor::{HealthMonitor, ServiceType}; +use retry::retry_forever; use wan::poll_wan_traffic; #[derive(Parser, Debug)] @@ -55,42 +59,87 @@ async fn main() { let args = Args::parse(); tokio::spawn(async move { - loop { - if let Err(err) = poll_wan_traffic( + // Use infinite retry mechanism to ensure WAN traffic monitoring is always available + retry_forever(|| async { + let health_monitor = HealthMonitor::global().await; + health_monitor.record_attempt(ServiceType::WanPolling).await; + + match poll_wan_traffic( args.luci_url.as_str(), args.luci_username.as_str(), args.luci_password.as_str(), ) .await { - error!("error: {}", err); + Ok(_) => { + info!("WAN traffic polling ended normally, restarting..."); + health_monitor.record_success(ServiceType::WanPolling).await; + Ok(()) + } + Err(err) => { + error!("WAN traffic polling failed: {}", err); + health_monitor.record_failure(ServiceType::WanPolling).await; + Err(err) + } } - - sleep(Duration::from_secs(1)).await; - error!("restart poll_wan_traffic!"); - } + }).await; }); let connect_addr = args.clash_url; - loop { - if let Err(err) = pipe(connect_addr.clone()).await { - error!("{}", err); - }; - sleep(Duration::from_secs(1)).await; - error!("restart clash!"); - } + // Use infinite retry mechanism to ensure WebSocket connection is always available + retry_forever(|| async { + let health_monitor = HealthMonitor::global().await; + health_monitor.record_attempt(ServiceType::WebSocket).await; + + match pipe(connect_addr.clone()).await { + Ok(_) => { + info!("WebSocket connection ended normally, reconnecting..."); + health_monitor.record_success(ServiceType::WebSocket).await; + Ok(()) + } + Err(err) => { + error!("WebSocket connection failed: {}", err); + health_monitor.record_failure(ServiceType::WebSocket).await; + Err(err) + } + } + }).await; } async fn pipe(connect_addr: String) -> anyhow::Result<()> { + info!("Attempting to connect to WebSocket: {}", connect_addr); let request = connect_addr.into_client_request()?; - let (mut ws_stream, _) = connect_async(request).await.map_err(|err| anyhow::anyhow!(err))?; - println!("WebSocket handshake has been successfully completed"); + // Add connection timeout + let (mut ws_stream, _) = timeout(Duration::from_secs(30), connect_async(request)) + .await + .map_err(|_| anyhow!("WebSocket connection timeout after 30 seconds"))? + .map_err(|err| anyhow!("WebSocket connection failed: {}", err))?; - while let Some(message) = ws_stream.next().await { - let message = message?; + info!("WebSocket handshake completed successfully"); + + let heartbeat_timeout = Duration::from_secs(60); // Consider connection abnormal if no message for 60 seconds + + loop { + let message_result = timeout(heartbeat_timeout, ws_stream.next()).await; + + let message = match message_result { + Ok(Some(Ok(msg))) => msg, + Ok(Some(Err(err))) => { + error!("WebSocket message error: {}", err); + return Err(anyhow!("WebSocket message error: {}", err)); + } + Ok(None) => { + warn!("WebSocket stream ended unexpectedly"); + return Err(anyhow!("WebSocket stream ended")); + } + Err(_) => { + error!("WebSocket heartbeat timeout - no message received in {} seconds", heartbeat_timeout.as_secs()); + return Err(anyhow!("WebSocket heartbeat timeout")); + } + }; match message { Message::Text(text) => { @@ -136,9 +185,24 @@ async fn pipe(connect_addr: String) -> anyhow::Result<()> { buf[24..32].copy_from_slice(&state.proxy_download_speed.to_le_bytes()); udp_server.publish_clash(&buf).await; } - Message::Close(_) => { - println!("Server requested close"); - break; + Message::Close(close_frame) => { + if let Some(frame) = close_frame { + info!("Server requested close: {} - {}", frame.code, frame.reason); + } else { + info!("Server requested close without reason"); + } + return Err(anyhow!("WebSocket connection closed by server")); + } + Message::Ping(payload) => { + debug!("Received ping, sending pong"); + // Automatically reply pong to keep connection alive + if let Err(err) = ws_stream.send(Message::Pong(payload)).await { + error!("Failed to send pong: {}", err); + return Err(anyhow!("Failed to send pong: {}", err)); + } + } + Message::Pong(_) => { + debug!("Received pong"); } _ => {} } diff --git a/src/retry.rs b/src/retry.rs new file mode 100644 index 0000000..58bc431 --- /dev/null +++ b/src/retry.rs @@ -0,0 +1,261 @@ +use std::time::Duration; +use log::{debug, warn, error}; +use tokio::time::sleep; + +/// Retry strategy configuration +#[derive(Clone, Debug)] +pub struct RetryConfig { + /// Maximum number of retry attempts + pub max_attempts: usize, + /// Initial retry delay + pub initial_delay: Duration, + /// Maximum retry delay + pub max_delay: Duration, + /// Exponential backoff multiplier + pub backoff_multiplier: f64, + /// Whether to enable jitter + pub jitter: bool, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + max_attempts: 10, + initial_delay: Duration::from_millis(500), + max_delay: Duration::from_secs(30), + backoff_multiplier: 2.0, + jitter: true, + } + } +} + +impl RetryConfig { + /// Create fast retry configuration (for lightweight operations) + pub fn fast() -> Self { + Self { + max_attempts: 5, + initial_delay: Duration::from_millis(100), + max_delay: Duration::from_secs(5), + backoff_multiplier: 1.5, + jitter: true, + } + } + + /// Create slow retry configuration (for heavyweight operations) + pub fn slow() -> Self { + Self { + max_attempts: 15, + initial_delay: Duration::from_secs(1), + max_delay: Duration::from_secs(60), + backoff_multiplier: 2.0, + jitter: true, + } + } + + /// Create infinite retry configuration (for critical services) + pub fn infinite() -> Self { + Self { + max_attempts: usize::MAX, + initial_delay: Duration::from_secs(1), + max_delay: Duration::from_secs(30), + backoff_multiplier: 2.0, + jitter: true, + } + } +} + +/// Retrier implementation +pub struct Retrier { + config: RetryConfig, + attempt: usize, + current_delay: Duration, +} + +impl Retrier { + pub fn new(config: RetryConfig) -> Self { + Self { + current_delay: config.initial_delay, + config, + attempt: 0, + } + } + + /// Execute retry operation + pub async fn retry(&mut self, operation: F) -> Result + where + F: Fn() -> Fut, + Fut: std::future::Future>, + E: std::fmt::Display, + { + loop { + self.attempt += 1; + + debug!("Attempting operation (attempt {}/{})", self.attempt, self.config.max_attempts); + + match operation().await { + Ok(result) => { + if self.attempt > 1 { + debug!("Operation succeeded after {} attempts", self.attempt); + } + return Ok(result); + } + Err(err) => { + if self.attempt >= self.config.max_attempts { + error!("Operation failed after {} attempts: {}", self.attempt, err); + return Err(err); + } + + warn!("Operation failed (attempt {}/{}): {}", self.attempt, self.config.max_attempts, err); + + let delay = self.calculate_delay(); + debug!("Retrying in {:?}", delay); + sleep(delay).await; + } + } + } + } + + /// Reset retrier state + pub fn reset(&mut self) { + self.attempt = 0; + self.current_delay = self.config.initial_delay; + } + + /// Calculate next retry delay + fn calculate_delay(&mut self) -> Duration { + let delay = self.current_delay; + + // Calculate next delay time (exponential backoff) + let next_delay_ms = (self.current_delay.as_millis() as f64 * self.config.backoff_multiplier) as u64; + self.current_delay = Duration::from_millis(next_delay_ms).min(self.config.max_delay); + + // Add jitter to avoid thundering herd effect + if self.config.jitter { + let jitter_range = delay.as_millis() as f64 * 0.1; // 10% jitter + let jitter = (rand::random::() - 0.5) * 2.0 * jitter_range; + let jittered_delay = (delay.as_millis() as f64 + jitter).max(0.0) as u64; + Duration::from_millis(jittered_delay) + } else { + delay + } + } + + /// Get current attempt count + pub fn attempt_count(&self) -> usize { + self.attempt + } +} + +/// Convenience function: execute operation with retry +pub async fn retry_with_config( + config: RetryConfig, + operation: F, +) -> Result +where + F: Fn() -> Fut, + Fut: std::future::Future>, + E: std::fmt::Display, +{ + let mut retrier = Retrier::new(config); + retrier.retry(operation).await +} + +/// Convenience function: execute retry operation with default configuration +pub async fn retry(operation: F) -> Result +where + F: Fn() -> Fut, + Fut: std::future::Future>, + E: std::fmt::Display, +{ + retry_with_config(RetryConfig::default(), operation).await +} + +/// Convenience function: infinite retry (for critical services) +pub async fn retry_forever(operation: F) -> T +where + F: Fn() -> Fut, + Fut: std::future::Future>, + E: std::fmt::Display, +{ + let mut retrier = Retrier::new(RetryConfig::infinite()); + loop { + match retrier.retry(&operation).await { + Ok(result) => return result, + Err(_) => { + // This should theoretically never happen since we set infinite retry + // But for safety, we reset the retrier and continue + retrier.reset(); + continue; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + + #[tokio::test] + async fn test_retry_success_on_first_attempt() { + let counter = Arc::new(AtomicUsize::new(0)); + let counter_clone = counter.clone(); + + let result = retry(|| async { + counter_clone.fetch_add(1, Ordering::SeqCst); + Ok::(42) + }).await; + + assert_eq!(result.unwrap(), 42); + assert_eq!(counter.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn test_retry_success_after_failures() { + let counter = Arc::new(AtomicUsize::new(0)); + let counter_clone = counter.clone(); + + let config = RetryConfig { + max_attempts: 3, + initial_delay: Duration::from_millis(1), + max_delay: Duration::from_millis(10), + backoff_multiplier: 2.0, + jitter: false, + }; + + let result = retry_with_config(config, || async { + let count = counter_clone.fetch_add(1, Ordering::SeqCst) + 1; + if count < 3 { + Err("not ready") + } else { + Ok(42) + } + }).await; + + assert_eq!(result.unwrap(), 42); + assert_eq!(counter.load(Ordering::SeqCst), 3); + } + + #[tokio::test] + async fn test_retry_max_attempts_exceeded() { + let counter = Arc::new(AtomicUsize::new(0)); + let counter_clone = counter.clone(); + + let config = RetryConfig { + max_attempts: 2, + initial_delay: Duration::from_millis(1), + max_delay: Duration::from_millis(10), + backoff_multiplier: 2.0, + jitter: false, + }; + + let result = retry_with_config(config, || async { + counter_clone.fetch_add(1, Ordering::SeqCst); + Err::("always fails") + }).await; + + assert!(result.is_err()); + assert_eq!(counter.load(Ordering::SeqCst), 2); + } +} diff --git a/src/wan.rs b/src/wan.rs index e7e5549..c06fbb7 100644 --- a/src/wan.rs +++ b/src/wan.rs @@ -11,16 +11,18 @@ use hyper::{ Request, }; use hyper_util::rt::TokioIo; -use log::debug; +use log::{debug, error, info, warn}; use serde::{Deserialize, Serialize}; use tokio::net::TcpStream; -use tokio::time::sleep; +use tokio::time::{sleep, timeout}; use crate::udp_server; pub type OpenWRTIFaces = Vec; pub async fn poll_wan_traffic(luci_url: &str, username: &str, password: &str) -> Result<()> { + info!("Starting WAN traffic polling for: {}", luci_url); + // Parse our URL... let url = format!("{}/admin/network/iface_status/wan", luci_url).parse::()?; @@ -29,31 +31,44 @@ pub async fn poll_wan_traffic(luci_url: &str, username: &str, password: &str) -> let port = url.port_u16().unwrap_or(80); let address = format!("{}:{}", host, port); + info!("Connecting to: {}", address); let mut cookies_str = String::new(); - // Open a TCP connection to the remote host + let connection_timeout = Duration::from_secs(10); + let request_timeout = Duration::from_secs(30); + + // Open a TCP connection to the remote host with timeout loop { - let stream = TcpStream::connect(address.clone()).await?; + let stream = timeout(connection_timeout, TcpStream::connect(address.clone())) + .await + .map_err(|_| anyhow!("Connection timeout after {} seconds", connection_timeout.as_secs()))? + .map_err(|e| anyhow!("Failed to connect to {}: {}", address, e))?; // Use an adapter to access something implementing `tokio::io` traits as if they implement // `hyper::rt` IO traits. let io = TokioIo::new(stream); - // Create the Hyper client - let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; + // Create the Hyper client with timeout + let (mut sender, conn) = timeout(connection_timeout, hyper::client::conn::http1::handshake(io)) + .await + .map_err(|_| anyhow!("HTTP handshake timeout after {} seconds", connection_timeout.as_secs()))? + .map_err(|e| anyhow!("HTTP handshake failed: {}", e))?; // Spawn a task to poll the connection, driving the HTTP state - tokio::task::spawn(async move { + let conn_handle = tokio::task::spawn(async move { if let Err(err) = conn.await { - println!("Connection failed: {:?}", err); + error!("HTTP connection failed: {:?}", err); } }); let mut prev_up = 0; let mut prev_down = 0; + let mut consecutive_errors = 0; + const MAX_CONSECUTIVE_ERRORS: u32 = 5; loop { if sender.is_closed() { + warn!("HTTP sender is closed, breaking connection loop"); break; } @@ -64,29 +79,52 @@ pub async fn poll_wan_traffic(luci_url: &str, username: &str, password: &str) -> let poll_req = generate_poll_wan_req(luci_url, &cookies_str)?; - // Await the response... - - let res = sender.send_request(poll_req.clone()).await?; + // Send request with timeout + let res = timeout(request_timeout, sender.send_request(poll_req.clone())) + .await + .map_err(|_| anyhow!("Request timeout after {} seconds", request_timeout.as_secs()))? + .map_err(|e| anyhow!("Request failed: {}", e))?; debug!("Response status: {}", res.status()); match res.status() { StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED => { + info!("Authentication required, logging in..."); cookies_str = login(luci_url, username, password).await?; + consecutive_errors = 0; // Reset error count + continue; + } + StatusCode::OK => { + consecutive_errors = 0; // Reset error count + } + status => { + consecutive_errors += 1; + error!("Unexpected response status: {} (consecutive errors: {})", status, consecutive_errors); + if consecutive_errors >= MAX_CONSECUTIVE_ERRORS { + return Err(anyhow!("Too many consecutive errors ({}), giving up", consecutive_errors)); + } + sleep(Duration::from_secs(1)).await; continue; } - _ => {} } - // asynchronously aggregate the chunks of the body - let body = res.collect().await?.aggregate(); + // asynchronously aggregate the chunks of the body with timeout + let body = timeout(request_timeout, res.collect()) + .await + .map_err(|_| anyhow!("Response body timeout after {} seconds", request_timeout.as_secs()))? + .map_err(|e| anyhow!("Failed to read response body: {}", e))? + .aggregate(); // try to parse as json with serde_json - let interfaces: OpenWRTIFaces = serde_json::from_reader(body.reader())?; + let interfaces: OpenWRTIFaces = serde_json::from_reader(body.reader()) + .map_err(|e| anyhow!("Failed to parse JSON response: {}", e))?; - let wan = interfaces.first().unwrap(); + let wan = interfaces.first() + .ok_or_else(|| anyhow!("No WAN interface found in response"))?; + // Detect counter reset (restart scenarios) if prev_up > wan.tx_bytes || prev_down > wan.rx_bytes { + info!("Counter reset detected, resetting baseline values"); prev_up = wan.tx_bytes; prev_down = wan.rx_bytes; } @@ -96,48 +134,81 @@ pub async fn poll_wan_traffic(luci_url: &str, username: &str, password: &str) -> prev_up = wan.tx_bytes; prev_down = wan.rx_bytes; + debug!("WAN traffic: ↑ {} bytes/s, ↓ {} bytes/s", up_speed, down_speed); + let udp_server = udp_server::UdpServer::global().await; let mut buf = [0; 16]; buf[0..8].copy_from_slice(&up_speed.to_le_bytes()); buf[8..16].copy_from_slice(&down_speed.to_le_bytes()); udp_server.publish_wan(&buf).await; - // println!("speed: ↑ {}, ↓ {}", up_speed, down_speed); - sleep(Duration::from_secs(1)).await; } } } async fn login(luci_url: &str, username: &str, password: &str) -> Result { + info!("Attempting to login to: {}", luci_url); + let url = format!("{}/admin/network/iface_status/wan", luci_url).parse::()?; let host = url.host().expect("uri has no host"); let port = url.port_u16().unwrap_or(80); let address = format!("{}:{}", host, port); - let stream = TcpStream::connect(address).await?; + let connection_timeout = Duration::from_secs(10); + let request_timeout = Duration::from_secs(30); + + let stream = timeout(connection_timeout, TcpStream::connect(address.clone())) + .await + .map_err(|_| anyhow!("Login connection timeout after {} seconds", connection_timeout.as_secs()))? + .map_err(|e| anyhow!("Failed to connect for login to {}: {}", address, e))?; + let io = TokioIo::new(stream); - let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; - tokio::task::spawn(async move { + let (mut sender, conn) = timeout(connection_timeout, hyper::client::conn::http1::handshake(io)) + .await + .map_err(|_| anyhow!("Login handshake timeout after {} seconds", connection_timeout.as_secs()))? + .map_err(|e| anyhow!("Login handshake failed: {}", e))?; + + let _conn_handle = tokio::task::spawn(async move { if let Err(err) = conn.await { - println!("Connection failed: {:?}", err); + error!("Login connection failed: {:?}", err); } }); let login_req = generate_login_req(luci_url, username, password)?; - let res = sender.send_request(login_req.clone()).await?; + let res = timeout(request_timeout, sender.send_request(login_req.clone())) + .await + .map_err(|_| anyhow!("Login request timeout after {} seconds", request_timeout.as_secs()))? + .map_err(|e| anyhow!("Login request failed: {}", e))?; - if res.status() == StatusCode::FORBIDDEN { - bail!("Login failed, got status: {}", res.status()); + match res.status() { + StatusCode::OK | StatusCode::FOUND | StatusCode::SEE_OTHER => { + info!("Login successful"); + } + StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED => { + error!("Login failed: Invalid credentials"); + bail!("Login failed with status: {}", res.status()); + } + status => { + error!("Login failed with unexpected status: {}", status); + bail!("Login failed with status: {}", status); + } } let cookies = res.headers().get_all(hyper::header::SET_COOKIE); let cookies = cookies .iter() - .map(|cookie| cookie.to_str().unwrap().split(';').next()); - let cookies = cookies.filter_map(|cookie| cookie); - let cookies_str = cookies.collect::>().join("; "); + .filter_map(|cookie| cookie.to_str().ok()) + .filter_map(|cookie| cookie.split(';').next()) + .collect::>(); + + if cookies.is_empty() { + warn!("No cookies received from login response"); + } + + let cookies_str = cookies.join("; "); + debug!("Received cookies: {}", cookies_str); Ok(cookies_str) } diff --git a/tests/integration_test.rs b/tests/integration_test.rs new file mode 100644 index 0000000..fd1d1d9 --- /dev/null +++ b/tests/integration_test.rs @@ -0,0 +1,148 @@ +use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::Arc; +use std::time::Duration; +use tokio::time::sleep; +use network_monitor::retry::{RetryConfig, retry_with_config}; + +/// 模拟网络故障的测试 +#[tokio::test] +async fn test_network_failure_recovery() { + // 模拟一个会失败几次然后成功的操作 + let attempt_count = Arc::new(AtomicU32::new(0)); + let max_failures = 3; + + let config = RetryConfig { + max_attempts: 10, + initial_delay: Duration::from_millis(10), + max_delay: Duration::from_millis(100), + backoff_multiplier: 1.5, + jitter: false, // 关闭抖动以便测试更可预测 + }; + + let attempt_count_clone = attempt_count.clone(); + let result = retry_with_config(config, move || { + let attempt_count = attempt_count_clone.clone(); + async move { + let current_attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1; + + if current_attempt <= max_failures { + // 模拟网络错误 + Err(format!("Network error on attempt {}", current_attempt)) + } else { + // 模拟恢复成功 + Ok(format!("Success on attempt {}", current_attempt)) + } + } + }).await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "Success on attempt 4"); + assert_eq!(attempt_count.load(Ordering::SeqCst), 4); +} + +/// 测试连接超时场景 +#[tokio::test] +async fn test_connection_timeout_scenario() { + let config = RetryConfig { + max_attempts: 3, + initial_delay: Duration::from_millis(5), + max_delay: Duration::from_millis(20), + backoff_multiplier: 2.0, + jitter: false, + }; + + let attempt_count = Arc::new(AtomicU32::new(0)); + + let attempt_count_clone = attempt_count.clone(); + let result: Result = retry_with_config(config, move || { + let attempt_count = attempt_count_clone.clone(); + async move { + attempt_count.fetch_add(1, Ordering::SeqCst); + + // 模拟连接超时 + sleep(Duration::from_millis(1)).await; + Err("Connection timeout") + } + }).await; + + assert!(result.is_err()); + assert_eq!(attempt_count.load(Ordering::SeqCst), 3); // 应该尝试了3次 +} + +/// 测试快速恢复场景 +#[tokio::test] +async fn test_fast_recovery() { + let config = RetryConfig::fast(); + + let attempt_count = Arc::new(AtomicU32::new(0)); + + let attempt_count_clone = attempt_count.clone(); + let result = retry_with_config(config, move || { + let attempt_count = attempt_count_clone.clone(); + async move { + let current_attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1; + + if current_attempt == 1 { + Err("Temporary failure") + } else { + Ok("Quick recovery") + } + } + }).await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "Quick recovery"); + assert_eq!(attempt_count.load(Ordering::SeqCst), 2); +} + +/// 测试慢速重试场景 +#[tokio::test] +async fn test_slow_retry_scenario() { + let config = RetryConfig::slow(); + + let attempt_count = Arc::new(AtomicU32::new(0)); + + let attempt_count_clone = attempt_count.clone(); + let result = retry_with_config(config, move || { + let attempt_count = attempt_count_clone.clone(); + async move { + let current_attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1; + + if current_attempt <= 2 { + Err("Service unavailable") + } else { + Ok("Service restored") + } + } + }).await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "Service restored"); + assert_eq!(attempt_count.load(Ordering::SeqCst), 3); +} + +/// 测试最大重试次数限制 +#[tokio::test] +async fn test_max_retry_limit() { + let config = RetryConfig { + max_attempts: 2, + initial_delay: Duration::from_millis(1), + max_delay: Duration::from_millis(5), + backoff_multiplier: 2.0, + jitter: false, + }; + + let attempt_count = Arc::new(AtomicU32::new(0)); + + let attempt_count_clone = attempt_count.clone(); + let result = retry_with_config(config, move || { + let attempt_count = attempt_count_clone.clone(); + async move { + attempt_count.fetch_add(1, Ordering::SeqCst); + Err::("Persistent failure") + } + }).await; + + assert!(result.is_err()); + assert_eq!(attempt_count.load(Ordering::SeqCst), 2); +}