feat: implement robust retry mechanism for API aggregation service
All checks were successful
Gitea Actions Demo / build (push) Successful in 4m54s

- Add unified retry strategy with exponential backoff and jitter
- Implement health monitoring system for service status tracking
- Enhance WebSocket connection with timeout and heartbeat detection
- Improve WAN polling with connection timeout and error handling
- Add comprehensive test suite for retry mechanisms
- Fix connection hanging issues that prevented proper retry recovery

Key improvements:
- RetryConfig with fast/slow/infinite retry strategies
- HealthMonitor for real-time service status tracking
- Connection timeouts to prevent hanging
- Automatic ping/pong handling for WebSocket
- Consecutive error counting and thresholds
- Detailed logging for better diagnostics

All tests passing: 11 tests (3 unit + 5 integration + 3 lib tests)
This commit is contained in:
Ivan Li 2025-06-30 16:49:34 +08:00
parent 7fd8639c5e
commit 2a9e34d345
8 changed files with 838 additions and 53 deletions

3
Cargo.lock generated
View File

@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo. # This file is automatically @generated by Cargo.
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 4
[[package]] [[package]]
name = "addr2line" name = "addr2line"
@ -778,6 +778,7 @@ dependencies = [
"hyper", "hyper",
"hyper-util", "hyper-util",
"log", "log",
"rand",
"serde", "serde",
"serde_json", "serde_json",
"tokio", "tokio",

View File

@ -15,6 +15,7 @@ http-body-util = "0.1"
hyper = {version = "1.5.0", features = ["full"]} hyper = {version = "1.5.0", features = ["full"]}
hyper-util = {version = "0.1", features = ["full"]} hyper-util = {version = "0.1", features = ["full"]}
log = "0.4.22" log = "0.4.22"
rand = "0.8.5"
serde = {version = "1.0.215", features = ["derive"]} serde = {version = "1.0.215", features = ["derive"]}
serde_json = "1.0.133" serde_json = "1.0.133"
tokio = {version = "1.41.1", features = ["full"]} tokio = {version = "1.41.1", features = ["full"]}

233
src/health_monitor.rs Normal file
View File

@ -0,0 +1,233 @@
use std::sync::atomic::{AtomicU64, AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, OnceCell};
use tokio::time::sleep;
use log::{info, warn, error};
/// Connection health status
#[derive(Debug, Clone)]
pub struct ConnectionHealth {
pub is_healthy: bool,
pub last_success: Option<Instant>,
pub last_failure: Option<Instant>,
pub consecutive_failures: u32,
pub total_attempts: u64,
pub total_successes: u64,
pub total_failures: u64,
pub uptime_percentage: f64,
}
/// Service type
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServiceType {
WebSocket,
WanPolling,
}
impl std::fmt::Display for ServiceType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ServiceType::WebSocket => write!(f, "WebSocket"),
ServiceType::WanPolling => write!(f, "WAN Polling"),
}
}
}
/// Health monitor
pub struct HealthMonitor {
websocket_stats: Arc<ServiceStats>,
wan_polling_stats: Arc<ServiceStats>,
}
struct ServiceStats {
is_healthy: AtomicBool,
last_success: Mutex<Option<Instant>>,
last_failure: Mutex<Option<Instant>>,
consecutive_failures: AtomicU64,
total_attempts: AtomicU64,
total_successes: AtomicU64,
total_failures: AtomicU64,
start_time: Instant,
}
impl ServiceStats {
fn new() -> Self {
Self {
is_healthy: AtomicBool::new(false),
last_success: Mutex::new(None),
last_failure: Mutex::new(None),
consecutive_failures: AtomicU64::new(0),
total_attempts: AtomicU64::new(0),
total_successes: AtomicU64::new(0),
total_failures: AtomicU64::new(0),
start_time: Instant::now(),
}
}
async fn record_attempt(&self) {
self.total_attempts.fetch_add(1, Ordering::Relaxed);
}
async fn record_success(&self) {
self.is_healthy.store(true, Ordering::Relaxed);
self.consecutive_failures.store(0, Ordering::Relaxed);
self.total_successes.fetch_add(1, Ordering::Relaxed);
*self.last_success.lock().await = Some(Instant::now());
}
async fn record_failure(&self) {
self.is_healthy.store(false, Ordering::Relaxed);
self.consecutive_failures.fetch_add(1, Ordering::Relaxed);
self.total_failures.fetch_add(1, Ordering::Relaxed);
*self.last_failure.lock().await = Some(Instant::now());
}
async fn get_health(&self) -> ConnectionHealth {
let total_attempts = self.total_attempts.load(Ordering::Relaxed);
let total_successes = self.total_successes.load(Ordering::Relaxed);
let total_failures = self.total_failures.load(Ordering::Relaxed);
let uptime_percentage = if total_attempts > 0 {
(total_successes as f64 / total_attempts as f64) * 100.0
} else {
0.0
};
ConnectionHealth {
is_healthy: self.is_healthy.load(Ordering::Relaxed),
last_success: *self.last_success.lock().await,
last_failure: *self.last_failure.lock().await,
consecutive_failures: self.consecutive_failures.load(Ordering::Relaxed) as u32,
total_attempts,
total_successes,
total_failures,
uptime_percentage,
}
}
}
impl HealthMonitor {
pub async fn global() -> &'static Self {
static HEALTH_MONITOR: OnceCell<HealthMonitor> = OnceCell::const_new();
HEALTH_MONITOR
.get_or_init(|| async {
let monitor = HealthMonitor::new();
// 启动健康状态报告任务
let monitor_clone = monitor.clone();
tokio::spawn(async move {
monitor_clone.start_health_reporting().await;
});
monitor
})
.await
}
fn new() -> Self {
Self {
websocket_stats: Arc::new(ServiceStats::new()),
wan_polling_stats: Arc::new(ServiceStats::new()),
}
}
fn clone(&self) -> Self {
Self {
websocket_stats: self.websocket_stats.clone(),
wan_polling_stats: self.wan_polling_stats.clone(),
}
}
/// Record service attempt
pub async fn record_attempt(&self, service: ServiceType) {
match service {
ServiceType::WebSocket => self.websocket_stats.record_attempt().await,
ServiceType::WanPolling => self.wan_polling_stats.record_attempt().await,
}
}
/// Record service success
pub async fn record_success(&self, service: ServiceType) {
match service {
ServiceType::WebSocket => self.websocket_stats.record_success().await,
ServiceType::WanPolling => self.wan_polling_stats.record_success().await,
}
}
/// Record service failure
pub async fn record_failure(&self, service: ServiceType) {
match service {
ServiceType::WebSocket => self.websocket_stats.record_failure().await,
ServiceType::WanPolling => self.wan_polling_stats.record_failure().await,
}
}
/// Get service health status
pub async fn get_health(&self, service: ServiceType) -> ConnectionHealth {
match service {
ServiceType::WebSocket => self.websocket_stats.get_health().await,
ServiceType::WanPolling => self.wan_polling_stats.get_health().await,
}
}
/// Get all services health status
pub async fn get_all_health(&self) -> (ConnectionHealth, ConnectionHealth) {
let websocket_health = self.websocket_stats.get_health().await;
let wan_polling_health = self.wan_polling_stats.get_health().await;
(websocket_health, wan_polling_health)
}
/// Start health status reporting task
async fn start_health_reporting(&self) {
let mut interval = tokio::time::interval(Duration::from_secs(60)); // Report every minute
loop {
interval.tick().await;
let (websocket_health, wan_health) = self.get_all_health().await;
info!("=== Health Status Report ===");
self.log_service_health("WebSocket", &websocket_health);
self.log_service_health("WAN Polling", &wan_health);
// 如果有服务不健康,发出警告
if !websocket_health.is_healthy {
warn!("WebSocket service is unhealthy! Consecutive failures: {}", websocket_health.consecutive_failures);
}
if !wan_health.is_healthy {
warn!("WAN Polling service is unhealthy! Consecutive failures: {}", wan_health.consecutive_failures);
}
// 如果连续失败次数过多,发出错误警报
if websocket_health.consecutive_failures > 10 {
error!("WebSocket service has {} consecutive failures!", websocket_health.consecutive_failures);
}
if wan_health.consecutive_failures > 10 {
error!("WAN Polling service has {} consecutive failures!", wan_health.consecutive_failures);
}
}
}
fn log_service_health(&self, service_name: &str, health: &ConnectionHealth) {
info!(
"{}: {} | Uptime: {:.1}% | Attempts: {} | Successes: {} | Failures: {} | Consecutive Failures: {}",
service_name,
if health.is_healthy { "HEALTHY" } else { "UNHEALTHY" },
health.uptime_percentage,
health.total_attempts,
health.total_successes,
health.total_failures,
health.consecutive_failures
);
if let Some(last_success) = health.last_success {
info!("{}: Last success: {:.1}s ago", service_name, last_success.elapsed().as_secs_f64());
}
if let Some(last_failure) = health.last_failure {
info!("{}: Last failure: {:.1}s ago", service_name, last_failure.elapsed().as_secs_f64());
}
}
}

6
src/lib.rs Normal file
View File

@ -0,0 +1,6 @@
pub mod retry;
pub mod health_monitor;
// 重新导出常用的类型和函数
pub use retry::{RetryConfig, Retrier, retry_with_config, retry, retry_forever};
pub use health_monitor::{HealthMonitor, ServiceType, ConnectionHealth};

View File

@ -2,9 +2,9 @@ use std::time::Duration;
use anyhow::{anyhow, bail}; use anyhow::{anyhow, bail};
use clap::Parser; use clap::Parser;
use futures_util::StreamExt; use futures_util::{SinkExt, StreamExt};
use log::{debug, error}; use log::{debug, error, info, warn};
use tokio::time::sleep; use tokio::time::timeout;
use tokio_tungstenite::{ use tokio_tungstenite::{
connect_async, connect_async,
tungstenite::{ tungstenite::{
@ -15,10 +15,14 @@ use tokio_tungstenite::{
}; };
mod clash_conn_msg; mod clash_conn_msg;
mod health_monitor;
pub mod retry;
mod statistics; mod statistics;
mod udp_server; mod udp_server;
mod wan; mod wan;
use health_monitor::{HealthMonitor, ServiceType};
use retry::retry_forever;
use wan::poll_wan_traffic; use wan::poll_wan_traffic;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -55,42 +59,87 @@ async fn main() {
let args = Args::parse(); let args = Args::parse();
tokio::spawn(async move { tokio::spawn(async move {
loop { // Use infinite retry mechanism to ensure WAN traffic monitoring is always available
if let Err(err) = poll_wan_traffic( retry_forever(|| async {
let health_monitor = HealthMonitor::global().await;
health_monitor.record_attempt(ServiceType::WanPolling).await;
match poll_wan_traffic(
args.luci_url.as_str(), args.luci_url.as_str(),
args.luci_username.as_str(), args.luci_username.as_str(),
args.luci_password.as_str(), args.luci_password.as_str(),
) )
.await .await
{ {
error!("error: {}", err); Ok(_) => {
info!("WAN traffic polling ended normally, restarting...");
health_monitor.record_success(ServiceType::WanPolling).await;
Ok(())
}
Err(err) => {
error!("WAN traffic polling failed: {}", err);
health_monitor.record_failure(ServiceType::WanPolling).await;
Err(err)
}
} }
}).await;
sleep(Duration::from_secs(1)).await;
error!("restart poll_wan_traffic!");
}
}); });
let connect_addr = args.clash_url; let connect_addr = args.clash_url;
loop { // Use infinite retry mechanism to ensure WebSocket connection is always available
if let Err(err) = pipe(connect_addr.clone()).await { retry_forever(|| async {
error!("{}", err); let health_monitor = HealthMonitor::global().await;
}; health_monitor.record_attempt(ServiceType::WebSocket).await;
sleep(Duration::from_secs(1)).await;
error!("restart clash!"); match pipe(connect_addr.clone()).await {
} Ok(_) => {
info!("WebSocket connection ended normally, reconnecting...");
health_monitor.record_success(ServiceType::WebSocket).await;
Ok(())
}
Err(err) => {
error!("WebSocket connection failed: {}", err);
health_monitor.record_failure(ServiceType::WebSocket).await;
Err(err)
}
}
}).await;
} }
async fn pipe(connect_addr: String) -> anyhow::Result<()> { async fn pipe(connect_addr: String) -> anyhow::Result<()> {
info!("Attempting to connect to WebSocket: {}", connect_addr);
let request = connect_addr.into_client_request()?; let request = connect_addr.into_client_request()?;
let (mut ws_stream, _) = connect_async(request).await.map_err(|err| anyhow::anyhow!(err))?; // Add connection timeout
println!("WebSocket handshake has been successfully completed"); let (mut ws_stream, _) = timeout(Duration::from_secs(30), connect_async(request))
.await
.map_err(|_| anyhow!("WebSocket connection timeout after 30 seconds"))?
.map_err(|err| anyhow!("WebSocket connection failed: {}", err))?;
while let Some(message) = ws_stream.next().await { info!("WebSocket handshake completed successfully");
let message = message?;
let heartbeat_timeout = Duration::from_secs(60); // Consider connection abnormal if no message for 60 seconds
loop {
let message_result = timeout(heartbeat_timeout, ws_stream.next()).await;
let message = match message_result {
Ok(Some(Ok(msg))) => msg,
Ok(Some(Err(err))) => {
error!("WebSocket message error: {}", err);
return Err(anyhow!("WebSocket message error: {}", err));
}
Ok(None) => {
warn!("WebSocket stream ended unexpectedly");
return Err(anyhow!("WebSocket stream ended"));
}
Err(_) => {
error!("WebSocket heartbeat timeout - no message received in {} seconds", heartbeat_timeout.as_secs());
return Err(anyhow!("WebSocket heartbeat timeout"));
}
};
match message { match message {
Message::Text(text) => { Message::Text(text) => {
@ -136,9 +185,24 @@ async fn pipe(connect_addr: String) -> anyhow::Result<()> {
buf[24..32].copy_from_slice(&state.proxy_download_speed.to_le_bytes()); buf[24..32].copy_from_slice(&state.proxy_download_speed.to_le_bytes());
udp_server.publish_clash(&buf).await; udp_server.publish_clash(&buf).await;
} }
Message::Close(_) => { Message::Close(close_frame) => {
println!("Server requested close"); if let Some(frame) = close_frame {
break; info!("Server requested close: {} - {}", frame.code, frame.reason);
} else {
info!("Server requested close without reason");
}
return Err(anyhow!("WebSocket connection closed by server"));
}
Message::Ping(payload) => {
debug!("Received ping, sending pong");
// Automatically reply pong to keep connection alive
if let Err(err) = ws_stream.send(Message::Pong(payload)).await {
error!("Failed to send pong: {}", err);
return Err(anyhow!("Failed to send pong: {}", err));
}
}
Message::Pong(_) => {
debug!("Received pong");
} }
_ => {} _ => {}
} }

261
src/retry.rs Normal file
View File

@ -0,0 +1,261 @@
use std::time::Duration;
use log::{debug, warn, error};
use tokio::time::sleep;
/// Retry strategy configuration
#[derive(Clone, Debug)]
pub struct RetryConfig {
/// Maximum number of retry attempts
pub max_attempts: usize,
/// Initial retry delay
pub initial_delay: Duration,
/// Maximum retry delay
pub max_delay: Duration,
/// Exponential backoff multiplier
pub backoff_multiplier: f64,
/// Whether to enable jitter
pub jitter: bool,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_attempts: 10,
initial_delay: Duration::from_millis(500),
max_delay: Duration::from_secs(30),
backoff_multiplier: 2.0,
jitter: true,
}
}
}
impl RetryConfig {
/// Create fast retry configuration (for lightweight operations)
pub fn fast() -> Self {
Self {
max_attempts: 5,
initial_delay: Duration::from_millis(100),
max_delay: Duration::from_secs(5),
backoff_multiplier: 1.5,
jitter: true,
}
}
/// Create slow retry configuration (for heavyweight operations)
pub fn slow() -> Self {
Self {
max_attempts: 15,
initial_delay: Duration::from_secs(1),
max_delay: Duration::from_secs(60),
backoff_multiplier: 2.0,
jitter: true,
}
}
/// Create infinite retry configuration (for critical services)
pub fn infinite() -> Self {
Self {
max_attempts: usize::MAX,
initial_delay: Duration::from_secs(1),
max_delay: Duration::from_secs(30),
backoff_multiplier: 2.0,
jitter: true,
}
}
}
/// Retrier implementation
pub struct Retrier {
config: RetryConfig,
attempt: usize,
current_delay: Duration,
}
impl Retrier {
pub fn new(config: RetryConfig) -> Self {
Self {
current_delay: config.initial_delay,
config,
attempt: 0,
}
}
/// Execute retry operation
pub async fn retry<F, Fut, T, E>(&mut self, operation: F) -> Result<T, E>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T, E>>,
E: std::fmt::Display,
{
loop {
self.attempt += 1;
debug!("Attempting operation (attempt {}/{})", self.attempt, self.config.max_attempts);
match operation().await {
Ok(result) => {
if self.attempt > 1 {
debug!("Operation succeeded after {} attempts", self.attempt);
}
return Ok(result);
}
Err(err) => {
if self.attempt >= self.config.max_attempts {
error!("Operation failed after {} attempts: {}", self.attempt, err);
return Err(err);
}
warn!("Operation failed (attempt {}/{}): {}", self.attempt, self.config.max_attempts, err);
let delay = self.calculate_delay();
debug!("Retrying in {:?}", delay);
sleep(delay).await;
}
}
}
}
/// Reset retrier state
pub fn reset(&mut self) {
self.attempt = 0;
self.current_delay = self.config.initial_delay;
}
/// Calculate next retry delay
fn calculate_delay(&mut self) -> Duration {
let delay = self.current_delay;
// Calculate next delay time (exponential backoff)
let next_delay_ms = (self.current_delay.as_millis() as f64 * self.config.backoff_multiplier) as u64;
self.current_delay = Duration::from_millis(next_delay_ms).min(self.config.max_delay);
// Add jitter to avoid thundering herd effect
if self.config.jitter {
let jitter_range = delay.as_millis() as f64 * 0.1; // 10% jitter
let jitter = (rand::random::<f64>() - 0.5) * 2.0 * jitter_range;
let jittered_delay = (delay.as_millis() as f64 + jitter).max(0.0) as u64;
Duration::from_millis(jittered_delay)
} else {
delay
}
}
/// Get current attempt count
pub fn attempt_count(&self) -> usize {
self.attempt
}
}
/// Convenience function: execute operation with retry
pub async fn retry_with_config<F, Fut, T, E>(
config: RetryConfig,
operation: F,
) -> Result<T, E>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T, E>>,
E: std::fmt::Display,
{
let mut retrier = Retrier::new(config);
retrier.retry(operation).await
}
/// Convenience function: execute retry operation with default configuration
pub async fn retry<F, Fut, T, E>(operation: F) -> Result<T, E>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T, E>>,
E: std::fmt::Display,
{
retry_with_config(RetryConfig::default(), operation).await
}
/// Convenience function: infinite retry (for critical services)
pub async fn retry_forever<F, Fut, T, E>(operation: F) -> T
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T, E>>,
E: std::fmt::Display,
{
let mut retrier = Retrier::new(RetryConfig::infinite());
loop {
match retrier.retry(&operation).await {
Ok(result) => return result,
Err(_) => {
// This should theoretically never happen since we set infinite retry
// But for safety, we reset the retrier and continue
retrier.reset();
continue;
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
#[tokio::test]
async fn test_retry_success_on_first_attempt() {
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = counter.clone();
let result = retry(|| async {
counter_clone.fetch_add(1, Ordering::SeqCst);
Ok::<i32, &str>(42)
}).await;
assert_eq!(result.unwrap(), 42);
assert_eq!(counter.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn test_retry_success_after_failures() {
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = counter.clone();
let config = RetryConfig {
max_attempts: 3,
initial_delay: Duration::from_millis(1),
max_delay: Duration::from_millis(10),
backoff_multiplier: 2.0,
jitter: false,
};
let result = retry_with_config(config, || async {
let count = counter_clone.fetch_add(1, Ordering::SeqCst) + 1;
if count < 3 {
Err("not ready")
} else {
Ok(42)
}
}).await;
assert_eq!(result.unwrap(), 42);
assert_eq!(counter.load(Ordering::SeqCst), 3);
}
#[tokio::test]
async fn test_retry_max_attempts_exceeded() {
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = counter.clone();
let config = RetryConfig {
max_attempts: 2,
initial_delay: Duration::from_millis(1),
max_delay: Duration::from_millis(10),
backoff_multiplier: 2.0,
jitter: false,
};
let result = retry_with_config(config, || async {
counter_clone.fetch_add(1, Ordering::SeqCst);
Err::<i32, &str>("always fails")
}).await;
assert!(result.is_err());
assert_eq!(counter.load(Ordering::SeqCst), 2);
}
}

View File

@ -11,16 +11,18 @@ use hyper::{
Request, Request,
}; };
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use log::debug; use log::{debug, error, info, warn};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::sleep; use tokio::time::{sleep, timeout};
use crate::udp_server; use crate::udp_server;
pub type OpenWRTIFaces = Vec<OpenWRTIFace>; pub type OpenWRTIFaces = Vec<OpenWRTIFace>;
pub async fn poll_wan_traffic(luci_url: &str, username: &str, password: &str) -> Result<()> { pub async fn poll_wan_traffic(luci_url: &str, username: &str, password: &str) -> Result<()> {
info!("Starting WAN traffic polling for: {}", luci_url);
// Parse our URL... // Parse our URL...
let url = format!("{}/admin/network/iface_status/wan", luci_url).parse::<hyper::Uri>()?; let url = format!("{}/admin/network/iface_status/wan", luci_url).parse::<hyper::Uri>()?;
@ -29,31 +31,44 @@ pub async fn poll_wan_traffic(luci_url: &str, username: &str, password: &str) ->
let port = url.port_u16().unwrap_or(80); let port = url.port_u16().unwrap_or(80);
let address = format!("{}:{}", host, port); let address = format!("{}:{}", host, port);
info!("Connecting to: {}", address);
let mut cookies_str = String::new(); let mut cookies_str = String::new();
// Open a TCP connection to the remote host let connection_timeout = Duration::from_secs(10);
let request_timeout = Duration::from_secs(30);
// Open a TCP connection to the remote host with timeout
loop { loop {
let stream = TcpStream::connect(address.clone()).await?; let stream = timeout(connection_timeout, TcpStream::connect(address.clone()))
.await
.map_err(|_| anyhow!("Connection timeout after {} seconds", connection_timeout.as_secs()))?
.map_err(|e| anyhow!("Failed to connect to {}: {}", address, e))?;
// Use an adapter to access something implementing `tokio::io` traits as if they implement // Use an adapter to access something implementing `tokio::io` traits as if they implement
// `hyper::rt` IO traits. // `hyper::rt` IO traits.
let io = TokioIo::new(stream); let io = TokioIo::new(stream);
// Create the Hyper client // Create the Hyper client with timeout
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; let (mut sender, conn) = timeout(connection_timeout, hyper::client::conn::http1::handshake(io))
.await
.map_err(|_| anyhow!("HTTP handshake timeout after {} seconds", connection_timeout.as_secs()))?
.map_err(|e| anyhow!("HTTP handshake failed: {}", e))?;
// Spawn a task to poll the connection, driving the HTTP state // Spawn a task to poll the connection, driving the HTTP state
tokio::task::spawn(async move { let conn_handle = tokio::task::spawn(async move {
if let Err(err) = conn.await { if let Err(err) = conn.await {
println!("Connection failed: {:?}", err); error!("HTTP connection failed: {:?}", err);
} }
}); });
let mut prev_up = 0; let mut prev_up = 0;
let mut prev_down = 0; let mut prev_down = 0;
let mut consecutive_errors = 0;
const MAX_CONSECUTIVE_ERRORS: u32 = 5;
loop { loop {
if sender.is_closed() { if sender.is_closed() {
warn!("HTTP sender is closed, breaking connection loop");
break; break;
} }
@ -64,29 +79,52 @@ pub async fn poll_wan_traffic(luci_url: &str, username: &str, password: &str) ->
let poll_req = generate_poll_wan_req(luci_url, &cookies_str)?; let poll_req = generate_poll_wan_req(luci_url, &cookies_str)?;
// Await the response... // Send request with timeout
let res = timeout(request_timeout, sender.send_request(poll_req.clone()))
let res = sender.send_request(poll_req.clone()).await?; .await
.map_err(|_| anyhow!("Request timeout after {} seconds", request_timeout.as_secs()))?
.map_err(|e| anyhow!("Request failed: {}", e))?;
debug!("Response status: {}", res.status()); debug!("Response status: {}", res.status());
match res.status() { match res.status() {
StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED => { StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED => {
info!("Authentication required, logging in...");
cookies_str = login(luci_url, username, password).await?; cookies_str = login(luci_url, username, password).await?;
consecutive_errors = 0; // Reset error count
continue;
}
StatusCode::OK => {
consecutive_errors = 0; // Reset error count
}
status => {
consecutive_errors += 1;
error!("Unexpected response status: {} (consecutive errors: {})", status, consecutive_errors);
if consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
return Err(anyhow!("Too many consecutive errors ({}), giving up", consecutive_errors));
}
sleep(Duration::from_secs(1)).await;
continue; continue;
} }
_ => {}
} }
// asynchronously aggregate the chunks of the body // asynchronously aggregate the chunks of the body with timeout
let body = res.collect().await?.aggregate(); let body = timeout(request_timeout, res.collect())
.await
.map_err(|_| anyhow!("Response body timeout after {} seconds", request_timeout.as_secs()))?
.map_err(|e| anyhow!("Failed to read response body: {}", e))?
.aggregate();
// try to parse as json with serde_json // try to parse as json with serde_json
let interfaces: OpenWRTIFaces = serde_json::from_reader(body.reader())?; let interfaces: OpenWRTIFaces = serde_json::from_reader(body.reader())
.map_err(|e| anyhow!("Failed to parse JSON response: {}", e))?;
let wan = interfaces.first().unwrap(); let wan = interfaces.first()
.ok_or_else(|| anyhow!("No WAN interface found in response"))?;
// Detect counter reset (restart scenarios)
if prev_up > wan.tx_bytes || prev_down > wan.rx_bytes { if prev_up > wan.tx_bytes || prev_down > wan.rx_bytes {
info!("Counter reset detected, resetting baseline values");
prev_up = wan.tx_bytes; prev_up = wan.tx_bytes;
prev_down = wan.rx_bytes; prev_down = wan.rx_bytes;
} }
@ -96,48 +134,81 @@ pub async fn poll_wan_traffic(luci_url: &str, username: &str, password: &str) ->
prev_up = wan.tx_bytes; prev_up = wan.tx_bytes;
prev_down = wan.rx_bytes; prev_down = wan.rx_bytes;
debug!("WAN traffic: ↑ {} bytes/s, ↓ {} bytes/s", up_speed, down_speed);
let udp_server = udp_server::UdpServer::global().await; let udp_server = udp_server::UdpServer::global().await;
let mut buf = [0; 16]; let mut buf = [0; 16];
buf[0..8].copy_from_slice(&up_speed.to_le_bytes()); buf[0..8].copy_from_slice(&up_speed.to_le_bytes());
buf[8..16].copy_from_slice(&down_speed.to_le_bytes()); buf[8..16].copy_from_slice(&down_speed.to_le_bytes());
udp_server.publish_wan(&buf).await; udp_server.publish_wan(&buf).await;
// println!("speed: ↑ {}, ↓ {}", up_speed, down_speed);
sleep(Duration::from_secs(1)).await; sleep(Duration::from_secs(1)).await;
} }
} }
} }
async fn login(luci_url: &str, username: &str, password: &str) -> Result<String> { async fn login(luci_url: &str, username: &str, password: &str) -> Result<String> {
info!("Attempting to login to: {}", luci_url);
let url = format!("{}/admin/network/iface_status/wan", luci_url).parse::<hyper::Uri>()?; let url = format!("{}/admin/network/iface_status/wan", luci_url).parse::<hyper::Uri>()?;
let host = url.host().expect("uri has no host"); let host = url.host().expect("uri has no host");
let port = url.port_u16().unwrap_or(80); let port = url.port_u16().unwrap_or(80);
let address = format!("{}:{}", host, port); let address = format!("{}:{}", host, port);
let stream = TcpStream::connect(address).await?; let connection_timeout = Duration::from_secs(10);
let request_timeout = Duration::from_secs(30);
let stream = timeout(connection_timeout, TcpStream::connect(address.clone()))
.await
.map_err(|_| anyhow!("Login connection timeout after {} seconds", connection_timeout.as_secs()))?
.map_err(|e| anyhow!("Failed to connect for login to {}: {}", address, e))?;
let io = TokioIo::new(stream); let io = TokioIo::new(stream);
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; let (mut sender, conn) = timeout(connection_timeout, hyper::client::conn::http1::handshake(io))
tokio::task::spawn(async move { .await
.map_err(|_| anyhow!("Login handshake timeout after {} seconds", connection_timeout.as_secs()))?
.map_err(|e| anyhow!("Login handshake failed: {}", e))?;
let _conn_handle = tokio::task::spawn(async move {
if let Err(err) = conn.await { if let Err(err) = conn.await {
println!("Connection failed: {:?}", err); error!("Login connection failed: {:?}", err);
} }
}); });
let login_req = generate_login_req(luci_url, username, password)?; let login_req = generate_login_req(luci_url, username, password)?;
let res = sender.send_request(login_req.clone()).await?; let res = timeout(request_timeout, sender.send_request(login_req.clone()))
.await
.map_err(|_| anyhow!("Login request timeout after {} seconds", request_timeout.as_secs()))?
.map_err(|e| anyhow!("Login request failed: {}", e))?;
if res.status() == StatusCode::FORBIDDEN { match res.status() {
bail!("Login failed, got status: {}", res.status()); StatusCode::OK | StatusCode::FOUND | StatusCode::SEE_OTHER => {
info!("Login successful");
}
StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED => {
error!("Login failed: Invalid credentials");
bail!("Login failed with status: {}", res.status());
}
status => {
error!("Login failed with unexpected status: {}", status);
bail!("Login failed with status: {}", status);
}
} }
let cookies = res.headers().get_all(hyper::header::SET_COOKIE); let cookies = res.headers().get_all(hyper::header::SET_COOKIE);
let cookies = cookies let cookies = cookies
.iter() .iter()
.map(|cookie| cookie.to_str().unwrap().split(';').next()); .filter_map(|cookie| cookie.to_str().ok())
let cookies = cookies.filter_map(|cookie| cookie); .filter_map(|cookie| cookie.split(';').next())
let cookies_str = cookies.collect::<Vec<_>>().join("; "); .collect::<Vec<_>>();
if cookies.is_empty() {
warn!("No cookies received from login response");
}
let cookies_str = cookies.join("; ");
debug!("Received cookies: {}", cookies_str);
Ok(cookies_str) Ok(cookies_str)
} }

148
tests/integration_test.rs Normal file
View File

@ -0,0 +1,148 @@
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::sleep;
use network_monitor::retry::{RetryConfig, retry_with_config};
/// 模拟网络故障的测试
#[tokio::test]
async fn test_network_failure_recovery() {
// 模拟一个会失败几次然后成功的操作
let attempt_count = Arc::new(AtomicU32::new(0));
let max_failures = 3;
let config = RetryConfig {
max_attempts: 10,
initial_delay: Duration::from_millis(10),
max_delay: Duration::from_millis(100),
backoff_multiplier: 1.5,
jitter: false, // 关闭抖动以便测试更可预测
};
let attempt_count_clone = attempt_count.clone();
let result = retry_with_config(config, move || {
let attempt_count = attempt_count_clone.clone();
async move {
let current_attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
if current_attempt <= max_failures {
// 模拟网络错误
Err(format!("Network error on attempt {}", current_attempt))
} else {
// 模拟恢复成功
Ok(format!("Success on attempt {}", current_attempt))
}
}
}).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "Success on attempt 4");
assert_eq!(attempt_count.load(Ordering::SeqCst), 4);
}
/// 测试连接超时场景
#[tokio::test]
async fn test_connection_timeout_scenario() {
let config = RetryConfig {
max_attempts: 3,
initial_delay: Duration::from_millis(5),
max_delay: Duration::from_millis(20),
backoff_multiplier: 2.0,
jitter: false,
};
let attempt_count = Arc::new(AtomicU32::new(0));
let attempt_count_clone = attempt_count.clone();
let result: Result<String, &str> = retry_with_config(config, move || {
let attempt_count = attempt_count_clone.clone();
async move {
attempt_count.fetch_add(1, Ordering::SeqCst);
// 模拟连接超时
sleep(Duration::from_millis(1)).await;
Err("Connection timeout")
}
}).await;
assert!(result.is_err());
assert_eq!(attempt_count.load(Ordering::SeqCst), 3); // 应该尝试了3次
}
/// 测试快速恢复场景
#[tokio::test]
async fn test_fast_recovery() {
let config = RetryConfig::fast();
let attempt_count = Arc::new(AtomicU32::new(0));
let attempt_count_clone = attempt_count.clone();
let result = retry_with_config(config, move || {
let attempt_count = attempt_count_clone.clone();
async move {
let current_attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
if current_attempt == 1 {
Err("Temporary failure")
} else {
Ok("Quick recovery")
}
}
}).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "Quick recovery");
assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
}
/// 测试慢速重试场景
#[tokio::test]
async fn test_slow_retry_scenario() {
let config = RetryConfig::slow();
let attempt_count = Arc::new(AtomicU32::new(0));
let attempt_count_clone = attempt_count.clone();
let result = retry_with_config(config, move || {
let attempt_count = attempt_count_clone.clone();
async move {
let current_attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
if current_attempt <= 2 {
Err("Service unavailable")
} else {
Ok("Service restored")
}
}
}).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "Service restored");
assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
}
/// 测试最大重试次数限制
#[tokio::test]
async fn test_max_retry_limit() {
let config = RetryConfig {
max_attempts: 2,
initial_delay: Duration::from_millis(1),
max_delay: Duration::from_millis(5),
backoff_multiplier: 2.0,
jitter: false,
};
let attempt_count = Arc::new(AtomicU32::new(0));
let attempt_count_clone = attempt_count.clone();
let result = retry_with_config(config, move || {
let attempt_count = attempt_count_clone.clone();
async move {
attempt_count.fetch_add(1, Ordering::SeqCst);
Err::<String, &str>("Persistent failure")
}
}).await;
assert!(result.is_err());
assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
}