Compare commits

...

8 Commits

Author SHA1 Message Date
b419463bb4 Update dependencies to latest versions
All checks were successful
Gitea Actions Demo / build (push) Successful in 1m12s
- Update rand from 0.8.5 to 0.9.1
- Update tokio-tungstenite from 0.24.0 to 0.27.0
- Update all other dependencies to latest compatible versions
- Fix API breaking change: replace text.into_bytes() with text.as_bytes()
- All tests passing, no functionality changes
2025-06-30 17:55:41 +08:00
e2bd5e9be5 Clean up Chinese comments and add comprehensive English README
- Replace all Chinese comments with English equivalents in:
  - src/health_monitor.rs
  - src/lib.rs
  - tests/integration_test.rs
- Add comprehensive README.md with:
  - Project overview and features
  - Architecture diagram
  - Installation and configuration guide
  - Data format specifications
  - Health monitoring documentation
  - Troubleshooting guide
2025-06-30 17:40:37 +08:00
2a9e34d345 feat: implement robust retry mechanism for API aggregation service
All checks were successful
Gitea Actions Demo / build (push) Successful in 4m54s
- Add unified retry strategy with exponential backoff and jitter
- Implement health monitoring system for service status tracking
- Enhance WebSocket connection with timeout and heartbeat detection
- Improve WAN polling with connection timeout and error handling
- Add comprehensive test suite for retry mechanisms
- Fix connection hanging issues that prevented proper retry recovery

Key improvements:
- RetryConfig with fast/slow/infinite retry strategies
- HealthMonitor for real-time service status tracking
- Connection timeouts to prevent hanging
- Automatic ping/pong handling for WebSocket
- Consecutive error counting and thresholds
- Detailed logging for better diagnostics

All tests passing: 11 tests (3 unit + 5 integration + 3 lib tests)
2025-06-30 16:49:34 +08:00
7fd8639c5e chore: update deps.
All checks were successful
Gitea Actions Demo / build (push) Successful in 4m21s
Signed-off-by: Ivan Li <ivanli2048@gmail.com>
2024-11-17 10:00:46 +08:00
c94243b838 chore: rm .env.
All checks were successful
Gitea Actions Demo / build (push) Successful in 1m1s
2024-05-03 04:49:06 +00:00
bb015bf662 fix: auto reconnection x2. 2024-05-03 04:00:27 +00:00
0bedfb261c fix: auto reconnection. 2024-05-02 05:06:47 +00:00
ca110fd56d feat: adding OpenWRT WAN port speeds. 2024-03-30 17:47:53 +08:00
14 changed files with 2318 additions and 470 deletions

.devcontainer/devcontainer.json

@@ -3,7 +3,7 @@
 {
     "name": "Rust",
     // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
-    "image": "mcr.microsoft.com/devcontainers/rust:0-1-bullseye",
+    "image": "mcr.microsoft.com/devcontainers/rust:1-bullseye",
    "features": {
        "ghcr.io/devcontainers/features/git:1": {},
        "ghcr.io/devcontainers/features/docker-in-docker:2": {}
@@ -13,8 +13,9 @@
        "extensions": [
            "mhutchie.git-graph",
            "donjayamanne.githistory",
-           "GitHub.copilot",
-           "eamodio.gitlens"
+           "eamodio.gitlens",
+           "rust-lang.rust-analyzer",
+           "Codeium.codeium"
        ]
    }
 }

.gitignore vendored

@@ -1 +1,2 @@
 /target
+.env

.vscode/launch.json vendored (new file)

@@ -0,0 +1,45 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "lldb",
"request": "launch",
"name": "Debug executable 'network-monitor'",
"cargo": {
"args": [
"build",
"--bin=network-monitor",
"--package=network-monitor"
],
"filter": {
"name": "network-monitor",
"kind": "bin"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug unit tests in executable 'network-monitor'",
"cargo": {
"args": [
"test",
"--no-run",
"--bin=network-monitor",
"--package=network-monitor"
],
"filter": {
"name": "network-monitor",
"kind": "bin"
}
},
"args": [],
"cwd": "${workspaceFolder}"
}
]
}

.vscode/settings.json vendored (new file)

@@ -0,0 +1,6 @@
{
"cSpell.words": [
"luci"
],
"lldb.library": "/Library/Developer/CommandLineTools/Library/PrivateFrameworks/LLDB.framework/Versions/A/LLDB"
}

Cargo.lock generated: diff suppressed because it is too large

Cargo.toml

@@ -6,13 +6,18 @@ version = "0.1.0"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-anyhow = "1.0.71"
-clap = {version = "4.2.7", features = ["derive"]}
-env_logger = "0.10.0"
-futures-util = "0.3.28"
-log = "0.4.17"
-serde = { version = "1.0.163", features = ["derive"] }
-serde_json = "1.0.96"
-tokio = {version = "1.11.0", features = ["full"]}
-tokio-tungstenite = "0.18.0"
-url = "2.3.1"
+anyhow = "1.0.93"
+clap = {version = "4.5.21", features = ["derive", "env"]}
+dotenvy = { version = "0.15.7", features = ["clap", "cli"] }
+env_logger = "0.11.5"
+futures-util = "0.3.31"
+http-body-util = "0.1"
+hyper = {version = "1.5.0", features = ["full"]}
+hyper-util = {version = "0.1", features = ["full"]}
+log = "0.4.22"
+rand = "0.9.1"
+serde = {version = "1.0.215", features = ["derive"]}
+serde_json = "1.0.133"
+tokio = {version = "1.41.1", features = ["full"]}
+tokio-tungstenite = "0.27.0"
+url = "2.5.3"

README.md (new file)

@@ -0,0 +1,252 @@
# Network Monitor
A robust network monitoring service written in Rust that tracks network traffic from multiple sources and provides real-time data via UDP broadcasting. This service is designed for high availability with automatic retry mechanisms and comprehensive health monitoring.
## Features
### 🔄 Dual Network Monitoring
- **Clash Proxy Monitoring**: Connects to Clash proxy via WebSocket to monitor proxy traffic statistics
- **WAN Interface Monitoring**: Polls OpenWRT/LuCI router interfaces for WAN traffic data
### 🚀 High Availability
- **Infinite Retry Mechanism**: Automatically recovers from network failures and service interruptions
- **Health Monitoring**: Comprehensive health tracking with detailed statistics and alerting
- **Exponential Backoff**: Smart retry strategy with configurable delays and jitter
### 📡 Real-time Data Broadcasting
- **UDP Server**: Broadcasts network statistics to connected clients
- **Client Management**: Automatic client discovery and connection management
- **Data Formats**: Structured binary data for efficient transmission
### 🛡️ Robust Error Handling
- **Connection Timeouts**: Configurable timeouts for all network operations
- **Graceful Degradation**: Continues operation even when one monitoring source fails
- **Detailed Logging**: Comprehensive logging for debugging and monitoring
## Architecture
```
┌─────────────────┐ WebSocket ┌─────────────────┐
│ Clash Proxy │◄───────────────►│ │
└─────────────────┘ │ │
│ Network │ UDP
┌─────────────────┐ HTTP/LuCI │ Monitor │◄──────────┐
│ OpenWRT Router │◄───────────────►│ Service │ │
└─────────────────┘ │ │ │
└─────────────────┘ │
┌─────────────────┐ UDP Data ┌─────────────────┐ │
│ Client 1 │◄───────────────►│ UDP Server │◄──────────┘
└─────────────────┘ └─────────────────┘
┌─────────────────┐
│ Client 2 │◄───────────────►
└─────────────────┘
```
## Installation
### Prerequisites
- Rust 1.70+ (for building from source)
- Docker (for containerized deployment)
### Building from Source
```bash
# Clone the repository
git clone <repository-url>
cd network-monitor
# Build the project
cargo build --release
# Run tests
cargo test
# Run the service
cargo run
```
### Docker Deployment
```bash
# Build the Docker image
docker build -t network-monitor .
# Run the container
docker run -d \
--name network-monitor \
-p 17890:17890/udp \
-e CLASH_URL="ws://192.168.1.1:9090/connections?token=your-token" \
-e LUCI_URL="http://192.168.1.1/cgi-bin/luci" \
-e LUCI_USERNAME="root" \
-e LUCI_PASSWORD="your-password" \
network-monitor
```
## Configuration
The service can be configured via command-line arguments or environment variables:
| Parameter | Environment Variable | Default Value | Description |
|-----------|---------------------|---------------|-------------|
| `-c, --clash-url` | `CLASH_URL` | `ws://192.168.1.1:9090/connections?token=123456` | Clash WebSocket URL |
| `-p, --listen-port` | `LISTEN_PORT` | `17890` | UDP server listen port |
| `-l, --luci-url` | `LUCI_URL` | `http://192.168.1.1/cgi-bin/luci` | OpenWRT LuCI base URL |
| `-u, --luci-username` | `LUCI_USERNAME` | `root` | LuCI authentication username |
| `-P, --luci-password` | `LUCI_PASSWORD` | `123456` | LuCI authentication password |
### Environment File
Create a `.env` file in the project root:
```env
CLASH_URL=ws://192.168.1.1:9090/connections?token=your-clash-token
LISTEN_PORT=17890
LUCI_URL=http://192.168.1.1/cgi-bin/luci
LUCI_USERNAME=root
LUCI_PASSWORD=your-router-password
```
## Data Formats
### Clash Traffic Data (32 bytes)
```
Bytes 0-7: Direct upload speed (u64, little-endian)
Bytes 8-15: Direct download speed (u64, little-endian)
Bytes 16-23: Proxy upload speed (u64, little-endian)
Bytes 24-31: Proxy download speed (u64, little-endian)
```
### WAN Traffic Data (16 bytes)
```
Bytes 0-7: WAN upload speed (u64, little-endian)
Bytes 8-15: WAN download speed (u64, little-endian)
```
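For reference, here is a minimal decoding sketch for these payloads. It is illustrative only (not part of the shipped code) and assumes exactly the layouts documented above, with little-endian `u64` fields:
```rust
/// Illustrative decoders for the documented datagram layouts (not part of the shipped code).
fn u64_at(buf: &[u8], offset: usize) -> u64 {
    u64::from_le_bytes(buf[offset..offset + 8].try_into().unwrap())
}

/// 32-byte Clash datagram -> (direct_up, direct_down, proxy_up, proxy_down), in bytes per second.
fn decode_clash(buf: &[u8; 32]) -> (u64, u64, u64, u64) {
    (u64_at(buf, 0), u64_at(buf, 8), u64_at(buf, 16), u64_at(buf, 24))
}

/// 16-byte WAN datagram -> (upload, download), in bytes per second.
fn decode_wan(buf: &[u8; 16]) -> (u64, u64) {
    (u64_at(buf, 0), u64_at(buf, 8))
}
```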
## Health Monitoring
The service includes comprehensive health monitoring with the following metrics:
- **Connection Status**: Real-time health status for each service
- **Uptime Percentage**: Success rate over time
- **Failure Tracking**: Consecutive failure counts and timestamps
- **Performance Metrics**: Total attempts, successes, and failures
Health reports are logged every minute with detailed statistics.
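These metrics are tracked per service by the `HealthMonitor` singleton (see `src/health_monitor.rs`). A minimal sketch of how a task records its outcomes against it, with a hypothetical `do_poll` operation standing in for real work:
```rust
use network_monitor::{HealthMonitor, ServiceType};

/// Hypothetical unit of work standing in for a real polling operation.
async fn do_poll() -> anyhow::Result<()> {
    Ok(())
}

/// Record one attempt and its outcome against the global health monitor.
async fn poll_once() -> anyhow::Result<()> {
    let monitor = HealthMonitor::global().await;
    monitor.record_attempt(ServiceType::WanPolling).await;

    match do_poll().await {
        Ok(()) => {
            monitor.record_success(ServiceType::WanPolling).await;
            Ok(())
        }
        Err(err) => {
            monitor.record_failure(ServiceType::WanPolling).await;
            Err(err)
        }
    }
}
```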
## Retry Strategy
The service implements a sophisticated retry mechanism:
- **Infinite Retries**: Critical services never give up
- **Exponential Backoff**: Delays increase exponentially with failures
- **Jitter**: Random delays prevent thundering herd effects
- **Configurable Limits**: Maximum delays and retry counts can be customized
### Retry Configurations
- **Fast Retry**: For lightweight operations (5 attempts, 100ms-5s delays)
- **Default Retry**: Balanced approach (10 attempts, 500ms-30s delays)
- **Slow Retry**: For heavyweight operations (15 attempts, 1s-60s delays)
- **Infinite Retry**: For critical services (unlimited attempts)
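These presets correspond to constructors on `RetryConfig` in `src/retry.rs`. A minimal usage sketch, with a hypothetical `fetch_status` operation standing in for a real network call:
```rust
use network_monitor::retry::{retry_with_config, RetryConfig};

/// Hypothetical fallible operation standing in for a real network call.
async fn fetch_status() -> Result<String, std::io::Error> {
    Ok("ok".to_string())
}

async fn check_with_retries() -> Result<String, String> {
    // Fast preset: 5 attempts, 100 ms initial delay, 5 s cap, jittered exponential backoff.
    retry_with_config(RetryConfig::fast(), || async {
        fetch_status()
            .await
            .map_err(|err| format!("fetch failed: {err}"))
    })
    .await
}
```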
## Logging
The service uses structured logging with multiple levels:
- **INFO**: Normal operation events and health reports
- **WARN**: Service health issues and recoverable errors
- **ERROR**: Critical failures and persistent issues
- **DEBUG**: Detailed operation information
Set the `RUST_LOG` environment variable to control log levels:
```bash
export RUST_LOG=info # or debug, warn, error
```
## Development
### Project Structure
```
src/
├── main.rs # Application entry point
├── lib.rs # Library exports
├── clash_conn_msg.rs # Clash message structures
├── health_monitor.rs # Health monitoring system
├── retry.rs # Retry mechanism implementation
├── statistics.rs # Traffic statistics processing
├── udp_server.rs # UDP broadcasting server
└── wan.rs # WAN traffic polling
tests/
└── integration_test.rs # Integration tests
```
### Running Tests
```bash
# Run all tests
cargo test
# Run with output
cargo test -- --nocapture
# Run specific test
cargo test test_network_failure_recovery
```
### Contributing
1. Fork the repository
2. Create a feature branch
3. Make your changes
4. Add tests for new functionality
5. Ensure all tests pass
6. Submit a pull request
## License
This project is licensed under the MIT License - see the LICENSE file for details.
## Troubleshooting
### Common Issues
1. **WebSocket Connection Failures**
- Verify Clash is running and accessible
- Check the WebSocket URL and authentication token
- Ensure network connectivity to the Clash instance
2. **LuCI Authentication Failures**
- Verify router credentials
- Check if the router is accessible
- Ensure the LuCI interface is enabled
3. **UDP Client Connection Issues**
- Verify the UDP port is not blocked by firewall
- Check if the service is binding to the correct interface
- Ensure clients are connecting to the correct port
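For a quick connectivity check, a minimal client sketch is shown below. It assumes that a client registers itself by sending any datagram to the service and that broadcasts carry the combined 48-byte payload (Clash speeds in bytes 0-31, WAN speeds in bytes 32-47); the service address used here is only an example:
```rust
use tokio::net::UdpSocket;

// Hypothetical connectivity check for the UDP broadcast:
// announce ourselves with a small datagram, then read one 48-byte payload.
#[tokio::main]
async fn main() -> std::io::Result<()> {
    let socket = UdpSocket::bind("0.0.0.0:0").await?;
    socket.send_to(b"hi", "192.168.1.10:17890").await?; // example service address

    let mut buf = [0u8; 48];
    let (len, _) = socket.recv_from(&mut buf).await?;
    println!("received {} bytes: {:02x?}", len, &buf[..len]);
    Ok(())
}
```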
### Debug Mode
Enable debug logging for detailed troubleshooting:
```bash
RUST_LOG=debug cargo run
```
This will provide detailed information about:
- Connection attempts and failures
- Retry mechanisms in action
- Health monitoring decisions
- UDP client management
- Data processing and broadcasting

src/health_monitor.rs (new file)

@@ -0,0 +1,233 @@
use std::sync::atomic::{AtomicU64, AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, OnceCell};
use tokio::time::sleep;
use log::{info, warn, error};
/// Connection health status
#[derive(Debug, Clone)]
pub struct ConnectionHealth {
pub is_healthy: bool,
pub last_success: Option<Instant>,
pub last_failure: Option<Instant>,
pub consecutive_failures: u32,
pub total_attempts: u64,
pub total_successes: u64,
pub total_failures: u64,
pub uptime_percentage: f64,
}
/// Service type
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServiceType {
WebSocket,
WanPolling,
}
impl std::fmt::Display for ServiceType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ServiceType::WebSocket => write!(f, "WebSocket"),
ServiceType::WanPolling => write!(f, "WAN Polling"),
}
}
}
/// Health monitor
pub struct HealthMonitor {
websocket_stats: Arc<ServiceStats>,
wan_polling_stats: Arc<ServiceStats>,
}
struct ServiceStats {
is_healthy: AtomicBool,
last_success: Mutex<Option<Instant>>,
last_failure: Mutex<Option<Instant>>,
consecutive_failures: AtomicU64,
total_attempts: AtomicU64,
total_successes: AtomicU64,
total_failures: AtomicU64,
start_time: Instant,
}
impl ServiceStats {
fn new() -> Self {
Self {
is_healthy: AtomicBool::new(false),
last_success: Mutex::new(None),
last_failure: Mutex::new(None),
consecutive_failures: AtomicU64::new(0),
total_attempts: AtomicU64::new(0),
total_successes: AtomicU64::new(0),
total_failures: AtomicU64::new(0),
start_time: Instant::now(),
}
}
async fn record_attempt(&self) {
self.total_attempts.fetch_add(1, Ordering::Relaxed);
}
async fn record_success(&self) {
self.is_healthy.store(true, Ordering::Relaxed);
self.consecutive_failures.store(0, Ordering::Relaxed);
self.total_successes.fetch_add(1, Ordering::Relaxed);
*self.last_success.lock().await = Some(Instant::now());
}
async fn record_failure(&self) {
self.is_healthy.store(false, Ordering::Relaxed);
self.consecutive_failures.fetch_add(1, Ordering::Relaxed);
self.total_failures.fetch_add(1, Ordering::Relaxed);
*self.last_failure.lock().await = Some(Instant::now());
}
async fn get_health(&self) -> ConnectionHealth {
let total_attempts = self.total_attempts.load(Ordering::Relaxed);
let total_successes = self.total_successes.load(Ordering::Relaxed);
let total_failures = self.total_failures.load(Ordering::Relaxed);
let uptime_percentage = if total_attempts > 0 {
(total_successes as f64 / total_attempts as f64) * 100.0
} else {
0.0
};
ConnectionHealth {
is_healthy: self.is_healthy.load(Ordering::Relaxed),
last_success: *self.last_success.lock().await,
last_failure: *self.last_failure.lock().await,
consecutive_failures: self.consecutive_failures.load(Ordering::Relaxed) as u32,
total_attempts,
total_successes,
total_failures,
uptime_percentage,
}
}
}
impl HealthMonitor {
pub async fn global() -> &'static Self {
static HEALTH_MONITOR: OnceCell<HealthMonitor> = OnceCell::const_new();
HEALTH_MONITOR
.get_or_init(|| async {
let monitor = HealthMonitor::new();
// Start health status reporting task
let monitor_clone = monitor.clone();
tokio::spawn(async move {
monitor_clone.start_health_reporting().await;
});
monitor
})
.await
}
fn new() -> Self {
Self {
websocket_stats: Arc::new(ServiceStats::new()),
wan_polling_stats: Arc::new(ServiceStats::new()),
}
}
fn clone(&self) -> Self {
Self {
websocket_stats: self.websocket_stats.clone(),
wan_polling_stats: self.wan_polling_stats.clone(),
}
}
/// Record service attempt
pub async fn record_attempt(&self, service: ServiceType) {
match service {
ServiceType::WebSocket => self.websocket_stats.record_attempt().await,
ServiceType::WanPolling => self.wan_polling_stats.record_attempt().await,
}
}
/// Record service success
pub async fn record_success(&self, service: ServiceType) {
match service {
ServiceType::WebSocket => self.websocket_stats.record_success().await,
ServiceType::WanPolling => self.wan_polling_stats.record_success().await,
}
}
/// Record service failure
pub async fn record_failure(&self, service: ServiceType) {
match service {
ServiceType::WebSocket => self.websocket_stats.record_failure().await,
ServiceType::WanPolling => self.wan_polling_stats.record_failure().await,
}
}
/// Get service health status
pub async fn get_health(&self, service: ServiceType) -> ConnectionHealth {
match service {
ServiceType::WebSocket => self.websocket_stats.get_health().await,
ServiceType::WanPolling => self.wan_polling_stats.get_health().await,
}
}
/// Get all services health status
pub async fn get_all_health(&self) -> (ConnectionHealth, ConnectionHealth) {
let websocket_health = self.websocket_stats.get_health().await;
let wan_polling_health = self.wan_polling_stats.get_health().await;
(websocket_health, wan_polling_health)
}
/// Start health status reporting task
async fn start_health_reporting(&self) {
let mut interval = tokio::time::interval(Duration::from_secs(60)); // Report every minute
loop {
interval.tick().await;
let (websocket_health, wan_health) = self.get_all_health().await;
info!("=== Health Status Report ===");
self.log_service_health("WebSocket", &websocket_health);
self.log_service_health("WAN Polling", &wan_health);
// Warn if any service is unhealthy
if !websocket_health.is_healthy {
warn!("WebSocket service is unhealthy! Consecutive failures: {}", websocket_health.consecutive_failures);
}
if !wan_health.is_healthy {
warn!("WAN Polling service is unhealthy! Consecutive failures: {}", wan_health.consecutive_failures);
}
// Alert if consecutive failures are too many
if websocket_health.consecutive_failures > 10 {
error!("WebSocket service has {} consecutive failures!", websocket_health.consecutive_failures);
}
if wan_health.consecutive_failures > 10 {
error!("WAN Polling service has {} consecutive failures!", wan_health.consecutive_failures);
}
}
}
fn log_service_health(&self, service_name: &str, health: &ConnectionHealth) {
info!(
"{}: {} | Uptime: {:.1}% | Attempts: {} | Successes: {} | Failures: {} | Consecutive Failures: {}",
service_name,
if health.is_healthy { "HEALTHY" } else { "UNHEALTHY" },
health.uptime_percentage,
health.total_attempts,
health.total_successes,
health.total_failures,
health.consecutive_failures
);
if let Some(last_success) = health.last_success {
info!("{}: Last success: {:.1}s ago", service_name, last_success.elapsed().as_secs_f64());
}
if let Some(last_failure) = health.last_failure {
info!("{}: Last failure: {:.1}s ago", service_name, last_failure.elapsed().as_secs_f64());
}
}
}

src/lib.rs (new file)

@@ -0,0 +1,6 @@
pub mod retry;
pub mod health_monitor;
// Re-export commonly used types and functions
pub use retry::{RetryConfig, Retrier, retry_with_config, retry, retry_forever};
pub use health_monitor::{HealthMonitor, ServiceType, ConnectionHealth};

src/main.rs

@@ -1,19 +1,31 @@
-use std::{format, time::Duration};
+use std::time::Duration;
 
+use anyhow::{anyhow, bail};
 use clap::Parser;
-use futures_util::{pin_mut, StreamExt};
-use log::info;
-use tokio::{
-    io::{stdout, AsyncWriteExt},
-    time::sleep,
-};
-use tokio_tungstenite::connect_async;
+use futures_util::{SinkExt, StreamExt};
+use log::{debug, error, info, warn};
+use tokio::time::timeout;
+use tokio_tungstenite::{
+    connect_async,
+    tungstenite::{
+        client::IntoClientRequest,
+        protocol::{frame::coding::CloseCode, CloseFrame},
+        Message,
+    },
+};
 
 mod clash_conn_msg;
+mod health_monitor;
+pub mod retry;
 mod statistics;
 mod udp_server;
+mod wan;
+
+use health_monitor::{HealthMonitor, ServiceType};
+use retry::retry_forever;
+use wan::poll_wan_traffic;
 
-#[derive(Parser)]
+#[derive(Parser, Debug)]
 #[clap(
     version = "0.1.0",
     author = "Ivan Li",
@@ -22,87 +34,179 @@ mod udp_server;
 struct Args {
     #[clap(
         short,
+        env,
         long,
-        default_value = "ws://192.168.31.1:9090/connections?token=123456"
+        default_value = "ws://192.168.1.1:9090/connections?token=123456"
     )]
     clash_url: String,
-    #[clap(short, long, default_value = "17890")]
+    #[clap(short = 'p', long, env, default_value = "17890")]
     listen_port: u16,
+    #[clap(short, long, env, default_value = "http://192.168.1.1/cgi-bin/luci")]
+    luci_url: String,
+    #[clap(short = 'u', long, env, default_value = "root")]
+    luci_username: String,
+    #[clap(short = 'P', long, env, default_value = "123456")]
+    luci_password: String,
 }
 
 #[tokio::main]
 async fn main() {
+    dotenvy::dotenv().ok();
     env_logger::init();
     println!("Hello, world!");
     let args = Args::parse();
 
+    tokio::spawn(async move {
+        // Use infinite retry mechanism to ensure WAN traffic monitoring is always available
+        retry_forever(|| async {
+            let health_monitor = HealthMonitor::global().await;
+            health_monitor.record_attempt(ServiceType::WanPolling).await;
+
+            match poll_wan_traffic(
+                args.luci_url.as_str(),
+                args.luci_username.as_str(),
+                args.luci_password.as_str(),
+            )
+            .await
+            {
+                Ok(_) => {
+                    info!("WAN traffic polling ended normally, restarting...");
+                    health_monitor.record_success(ServiceType::WanPolling).await;
+                    Ok(())
+                }
+                Err(err) => {
+                    error!("WAN traffic polling failed: {}", err);
+                    health_monitor.record_failure(ServiceType::WanPolling).await;
+                    Err(err)
+                }
+            }
+        }).await;
+    });
+
     let connect_addr = args.clash_url;
-    loop {
-        pipe(connect_addr.clone()).await;
-        sleep(Duration::from_secs(1)).await;
-        info!("restart!");
-    }
-}
 
-async fn pipe(connect_addr: String) {
-    let url = url::Url::parse(&connect_addr).unwrap();
-
-    let (ws_stream, _) = connect_async(url).await.expect("Failed to connect");
-    println!("WebSocket handshake has been successfully completed");
-
-    let (_, read) = ws_stream.split();
-    let ws_to_stdout = {
-        read.for_each(|message| async {
-            let data = message.unwrap().into_data();
-            let wrapper = serde_json::from_slice::<clash_conn_msg::ClashConnectionsWrapper>(&data);
-            if let Err(err) = wrapper {
-                stdout()
-                    .write_all(format!("Error: {}\n", err).as_bytes())
-                    .await
-                    .unwrap();
-                return;
-            }
-            let wrapper = wrapper.unwrap();
-
-            let statistics_manager = statistics::StatisticsManager::global().await;
-
-            statistics_manager.update(&wrapper).await;
-
-            let state = statistics_manager.get_state().await;
-
-            // stdout()
-            //     .write_all(
-            //         format!(
-            //             "len: {},speed_upload: {}, speed_download: {}, update_direct_speed: {}, download_direct_speed: {}, update_proxy_speed: {}, download_proxy_speed: {}\n",
-            //             state.connections,
-            //             state.speed_upload,
-            //             state.speed_download,
-            //             state.direct_upload_speed,
-            //             state.direct_download_speed,
-            //             state.proxy_upload_speed,
-            //             state.proxy_download_speed,
-            //         )
-            //         .as_bytes(),
-            //     )
-            //     .await
-            //     .unwrap();
-
-            let udp_server = udp_server::UdpServer::global().await;
-            let mut buf = [0; 32];
-            buf[0..8].copy_from_slice(&state.direct_upload_speed.to_le_bytes());
-            buf[8..16].copy_from_slice(&state.direct_download_speed.to_le_bytes());
-            buf[16..24].copy_from_slice(&state.proxy_upload_speed.to_le_bytes());
-            buf[24..32].copy_from_slice(&state.proxy_download_speed.to_le_bytes());
-            udp_server.publish(&buf).await;
-        })
-    };
-
-    pin_mut!(ws_to_stdout);
-    ws_to_stdout.await;
+    // Use infinite retry mechanism to ensure WebSocket connection is always available
+    retry_forever(|| async {
+        let health_monitor = HealthMonitor::global().await;
+        health_monitor.record_attempt(ServiceType::WebSocket).await;
+
+        match pipe(connect_addr.clone()).await {
+            Ok(_) => {
+                info!("WebSocket connection ended normally, reconnecting...");
+                health_monitor.record_success(ServiceType::WebSocket).await;
+                Ok(())
+            }
+            Err(err) => {
+                error!("WebSocket connection failed: {}", err);
+                health_monitor.record_failure(ServiceType::WebSocket).await;
+                Err(err)
+            }
+        }
+    }).await;
+}
+
+async fn pipe(connect_addr: String) -> anyhow::Result<()> {
+    info!("Attempting to connect to WebSocket: {}", connect_addr);
+
+    let request = connect_addr.into_client_request()?;
+
+    // Add connection timeout
+    let (mut ws_stream, _) = timeout(Duration::from_secs(30), connect_async(request))
+        .await
+        .map_err(|_| anyhow!("WebSocket connection timeout after 30 seconds"))?
+        .map_err(|err| anyhow!("WebSocket connection failed: {}", err))?;
+
+    info!("WebSocket handshake completed successfully");
+
+    let heartbeat_timeout = Duration::from_secs(60); // Consider connection abnormal if no message for 60 seconds
+
+    loop {
+        let message_result = timeout(heartbeat_timeout, ws_stream.next()).await;
+
+        let message = match message_result {
+            Ok(Some(Ok(msg))) => msg,
+            Ok(Some(Err(err))) => {
+                error!("WebSocket message error: {}", err);
+                return Err(anyhow!("WebSocket message error: {}", err));
+            }
+            Ok(None) => {
+                warn!("WebSocket stream ended unexpectedly");
+                return Err(anyhow!("WebSocket stream ended"));
+            }
+            Err(_) => {
+                error!("WebSocket heartbeat timeout - no message received in {} seconds", heartbeat_timeout.as_secs());
+                return Err(anyhow!("WebSocket heartbeat timeout"));
+            }
+        };
+
+        match message {
+            Message::Text(text) => {
+                let data = text.as_bytes();
+                let wrapper =
+                    serde_json::from_slice::<clash_conn_msg::ClashConnectionsWrapper>(&data);
+                if let Err(err) = wrapper {
+                    error!("parse message failed. {}", err);
+                    ws_stream
+                        .close(Some(CloseFrame {
+                            code: CloseCode::Unsupported,
+                            reason: "parse message failed".into(),
+                        }))
+                        .await
+                        .map_err(|err| anyhow!(err))?;
+                    bail!(err);
+                }
+                let wrapper = wrapper.unwrap();
+
+                let statistics_manager = statistics::StatisticsManager::global().await;
+                statistics_manager.update(&wrapper).await;
+                let state = statistics_manager.get_state().await;
+
+                debug!("len: {},speed_upload: {}, speed_download: {}, update_direct_speed: {}, download_direct_speed: {}, update_proxy_speed: {}, download_proxy_speed: {}",
+                    state.connections,
+                    state.speed_upload,
+                    state.speed_download,
+                    state.direct_upload_speed,
+                    state.direct_download_speed,
+                    state.proxy_upload_speed,
+                    state.proxy_download_speed,
+                );
+
+                let udp_server = udp_server::UdpServer::global().await;
+                let mut buf = [0; 32];
+                buf[0..8].copy_from_slice(&state.direct_upload_speed.to_le_bytes());
+                buf[8..16].copy_from_slice(&state.direct_download_speed.to_le_bytes());
+                buf[16..24].copy_from_slice(&state.proxy_upload_speed.to_le_bytes());
+                buf[24..32].copy_from_slice(&state.proxy_download_speed.to_le_bytes());
+                udp_server.publish_clash(&buf).await;
+            }
+            Message::Close(close_frame) => {
+                if let Some(frame) = close_frame {
+                    info!("Server requested close: {} - {}", frame.code, frame.reason);
+                } else {
+                    info!("Server requested close without reason");
+                }
+                return Err(anyhow!("WebSocket connection closed by server"));
+            }
+            Message::Ping(payload) => {
+                debug!("Received ping, sending pong");
+                // Automatically reply pong to keep connection alive
+                if let Err(err) = ws_stream.send(Message::Pong(payload)).await {
+                    error!("Failed to send pong: {}", err);
+                    return Err(anyhow!("Failed to send pong: {}", err));
+                }
+            }
+            Message::Pong(_) => {
+                debug!("Received pong");
+            }
+            _ => {}
+        }
+    }
+
+    Ok(())
 }

src/retry.rs (new file)

@@ -0,0 +1,261 @@
use std::time::Duration;
use log::{debug, warn, error};
use tokio::time::sleep;
/// Retry strategy configuration
#[derive(Clone, Debug)]
pub struct RetryConfig {
/// Maximum number of retry attempts
pub max_attempts: usize,
/// Initial retry delay
pub initial_delay: Duration,
/// Maximum retry delay
pub max_delay: Duration,
/// Exponential backoff multiplier
pub backoff_multiplier: f64,
/// Whether to enable jitter
pub jitter: bool,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_attempts: 10,
initial_delay: Duration::from_millis(500),
max_delay: Duration::from_secs(30),
backoff_multiplier: 2.0,
jitter: true,
}
}
}
impl RetryConfig {
/// Create fast retry configuration (for lightweight operations)
pub fn fast() -> Self {
Self {
max_attempts: 5,
initial_delay: Duration::from_millis(100),
max_delay: Duration::from_secs(5),
backoff_multiplier: 1.5,
jitter: true,
}
}
/// Create slow retry configuration (for heavyweight operations)
pub fn slow() -> Self {
Self {
max_attempts: 15,
initial_delay: Duration::from_secs(1),
max_delay: Duration::from_secs(60),
backoff_multiplier: 2.0,
jitter: true,
}
}
/// Create infinite retry configuration (for critical services)
pub fn infinite() -> Self {
Self {
max_attempts: usize::MAX,
initial_delay: Duration::from_secs(1),
max_delay: Duration::from_secs(30),
backoff_multiplier: 2.0,
jitter: true,
}
}
}
/// Retrier implementation
pub struct Retrier {
config: RetryConfig,
attempt: usize,
current_delay: Duration,
}
impl Retrier {
pub fn new(config: RetryConfig) -> Self {
Self {
current_delay: config.initial_delay,
config,
attempt: 0,
}
}
/// Execute retry operation
pub async fn retry<F, Fut, T, E>(&mut self, operation: F) -> Result<T, E>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T, E>>,
E: std::fmt::Display,
{
loop {
self.attempt += 1;
debug!("Attempting operation (attempt {}/{})", self.attempt, self.config.max_attempts);
match operation().await {
Ok(result) => {
if self.attempt > 1 {
debug!("Operation succeeded after {} attempts", self.attempt);
}
return Ok(result);
}
Err(err) => {
if self.attempt >= self.config.max_attempts {
error!("Operation failed after {} attempts: {}", self.attempt, err);
return Err(err);
}
warn!("Operation failed (attempt {}/{}): {}", self.attempt, self.config.max_attempts, err);
let delay = self.calculate_delay();
debug!("Retrying in {:?}", delay);
sleep(delay).await;
}
}
}
}
/// Reset retrier state
pub fn reset(&mut self) {
self.attempt = 0;
self.current_delay = self.config.initial_delay;
}
/// Calculate next retry delay
fn calculate_delay(&mut self) -> Duration {
let delay = self.current_delay;
// Calculate next delay time (exponential backoff)
let next_delay_ms = (self.current_delay.as_millis() as f64 * self.config.backoff_multiplier) as u64;
self.current_delay = Duration::from_millis(next_delay_ms).min(self.config.max_delay);
// Add jitter to avoid thundering herd effect
if self.config.jitter {
let jitter_range = delay.as_millis() as f64 * 0.1; // 10% jitter
let jitter = (rand::random::<f64>() - 0.5) * 2.0 * jitter_range;
let jittered_delay = (delay.as_millis() as f64 + jitter).max(0.0) as u64;
Duration::from_millis(jittered_delay)
} else {
delay
}
}
/// Get current attempt count
pub fn attempt_count(&self) -> usize {
self.attempt
}
}
/// Convenience function: execute operation with retry
pub async fn retry_with_config<F, Fut, T, E>(
config: RetryConfig,
operation: F,
) -> Result<T, E>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T, E>>,
E: std::fmt::Display,
{
let mut retrier = Retrier::new(config);
retrier.retry(operation).await
}
/// Convenience function: execute retry operation with default configuration
pub async fn retry<F, Fut, T, E>(operation: F) -> Result<T, E>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T, E>>,
E: std::fmt::Display,
{
retry_with_config(RetryConfig::default(), operation).await
}
/// Convenience function: infinite retry (for critical services)
pub async fn retry_forever<F, Fut, T, E>(operation: F) -> T
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T, E>>,
E: std::fmt::Display,
{
let mut retrier = Retrier::new(RetryConfig::infinite());
loop {
match retrier.retry(&operation).await {
Ok(result) => return result,
Err(_) => {
// This should theoretically never happen since we set infinite retry
// But for safety, we reset the retrier and continue
retrier.reset();
continue;
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
#[tokio::test]
async fn test_retry_success_on_first_attempt() {
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = counter.clone();
let result = retry(|| async {
counter_clone.fetch_add(1, Ordering::SeqCst);
Ok::<i32, &str>(42)
}).await;
assert_eq!(result.unwrap(), 42);
assert_eq!(counter.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn test_retry_success_after_failures() {
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = counter.clone();
let config = RetryConfig {
max_attempts: 3,
initial_delay: Duration::from_millis(1),
max_delay: Duration::from_millis(10),
backoff_multiplier: 2.0,
jitter: false,
};
let result = retry_with_config(config, || async {
let count = counter_clone.fetch_add(1, Ordering::SeqCst) + 1;
if count < 3 {
Err("not ready")
} else {
Ok(42)
}
}).await;
assert_eq!(result.unwrap(), 42);
assert_eq!(counter.load(Ordering::SeqCst), 3);
}
#[tokio::test]
async fn test_retry_max_attempts_exceeded() {
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = counter.clone();
let config = RetryConfig {
max_attempts: 2,
initial_delay: Duration::from_millis(1),
max_delay: Duration::from_millis(10),
backoff_multiplier: 2.0,
jitter: false,
};
let result = retry_with_config(config, || async {
counter_clone.fetch_add(1, Ordering::SeqCst);
Err::<i32, &str>("always fails")
}).await;
assert!(result.is_err());
assert_eq!(counter.load(Ordering::SeqCst), 2);
}
}

src/udp_server.rs

@@ -4,13 +4,13 @@ use std::{
     sync::Arc,
 };
 
+use clap::Parser;
 use log::{error, info};
 use tokio::{
     net::UdpSocket,
     sync::{Mutex, OnceCell},
     time::Instant,
 };
-use clap::Parser;
 
 use crate::Args;
 
@@ -24,6 +24,7 @@ pub struct UdpServer {
     listen_addr: Ipv4Addr,
     listen_port: u16,
     clients: Arc<Mutex<HashMap<SocketAddr, Client>>>,
+    tx_buffer: Arc<Mutex<[u8; 48]>>,
 }
 
 impl UdpServer {
@@ -54,6 +55,7 @@ impl UdpServer {
             listen_addr,
             listen_port,
             clients: Arc::new(Mutex::new(HashMap::new())),
+            tx_buffer: Arc::new(Mutex::new([0u8; 48])),
         }
     }
 
@@ -95,7 +97,25 @@
         }
     }
 
-    pub async fn publish(&self, buf: &[u8]) {
+    pub async fn publish_clash(&self, buf: &[u8]) {
+        let mut tx_buffer = self.tx_buffer.lock().await;
+        tx_buffer[..32].copy_from_slice(buf);
+        let buf = tx_buffer.clone();
+        drop(tx_buffer);
+
+        self.publish(&buf).await;
+    }
+
+    pub async fn publish_wan(&self, buf: &[u8]) {
+        let mut tx_buffer = self.tx_buffer.lock().await;
+        tx_buffer[32..].copy_from_slice(buf);
+        let buf = tx_buffer.clone();
+        drop(tx_buffer);
+
+        self.publish(&buf).await;
+    }
+
+    async fn publish(&self, buf: &[u8]) {
         let mut to_remove = Vec::new();
 
         let mut clients = self.clients.lock().await;
@@ -103,7 +123,7 @@
             if client.last_seen.elapsed().as_secs() > 10 {
                 to_remove.push(addr.clone());
             } else {
-                if let Err(err) = client.socket.send(buf).await {
+                if let Err(err) = client.socket.send(&buf).await {
                     error!("Failed to send data to {}: {}", addr, err);
                 } else {
                     info!("Sent data to {}", addr);

src/wan.rs (new file)

@@ -0,0 +1,265 @@
use std::time::Duration;
use anyhow::anyhow;
use anyhow::bail;
use anyhow::Result;
use http_body_util::BodyExt;
use http_body_util::Empty;
use hyper::StatusCode;
use hyper::{
body::{Buf, Bytes},
Request,
};
use hyper_util::rt::TokioIo;
use log::{debug, error, info, warn};
use serde::{Deserialize, Serialize};
use tokio::net::TcpStream;
use tokio::time::{sleep, timeout};
use crate::udp_server;
pub type OpenWRTIFaces = Vec<OpenWRTIFace>;
pub async fn poll_wan_traffic(luci_url: &str, username: &str, password: &str) -> Result<()> {
info!("Starting WAN traffic polling for: {}", luci_url);
// Parse our URL...
let url = format!("{}/admin/network/iface_status/wan", luci_url).parse::<hyper::Uri>()?;
// Get the host and the port
let host = url.host().expect("uri has no host");
let port = url.port_u16().unwrap_or(80);
let address = format!("{}:{}", host, port);
info!("Connecting to: {}", address);
let mut cookies_str = String::new();
let connection_timeout = Duration::from_secs(10);
let request_timeout = Duration::from_secs(30);
// Open a TCP connection to the remote host with timeout
loop {
let stream = timeout(connection_timeout, TcpStream::connect(address.clone()))
.await
.map_err(|_| anyhow!("Connection timeout after {} seconds", connection_timeout.as_secs()))?
.map_err(|e| anyhow!("Failed to connect to {}: {}", address, e))?;
// Use an adapter to access something implementing `tokio::io` traits as if they implement
// `hyper::rt` IO traits.
let io = TokioIo::new(stream);
// Create the Hyper client with timeout
let (mut sender, conn) = timeout(connection_timeout, hyper::client::conn::http1::handshake(io))
.await
.map_err(|_| anyhow!("HTTP handshake timeout after {} seconds", connection_timeout.as_secs()))?
.map_err(|e| anyhow!("HTTP handshake failed: {}", e))?;
// Spawn a task to poll the connection, driving the HTTP state
let conn_handle = tokio::task::spawn(async move {
if let Err(err) = conn.await {
error!("HTTP connection failed: {:?}", err);
}
});
let mut prev_up = 0;
let mut prev_down = 0;
let mut consecutive_errors = 0;
const MAX_CONSECUTIVE_ERRORS: u32 = 5;
loop {
if sender.is_closed() {
warn!("HTTP sender is closed, breaking connection loop");
break;
}
if !sender.is_ready() {
sleep(Duration::from_millis(100)).await;
continue;
}
let poll_req = generate_poll_wan_req(luci_url, &cookies_str)?;
// Send request with timeout
let res = timeout(request_timeout, sender.send_request(poll_req.clone()))
.await
.map_err(|_| anyhow!("Request timeout after {} seconds", request_timeout.as_secs()))?
.map_err(|e| anyhow!("Request failed: {}", e))?;
debug!("Response status: {}", res.status());
match res.status() {
StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED => {
info!("Authentication required, logging in...");
cookies_str = login(luci_url, username, password).await?;
consecutive_errors = 0; // Reset error count
continue;
}
StatusCode::OK => {
consecutive_errors = 0; // Reset error count
}
status => {
consecutive_errors += 1;
error!("Unexpected response status: {} (consecutive errors: {})", status, consecutive_errors);
if consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
return Err(anyhow!("Too many consecutive errors ({}), giving up", consecutive_errors));
}
sleep(Duration::from_secs(1)).await;
continue;
}
}
// asynchronously aggregate the chunks of the body with timeout
let body = timeout(request_timeout, res.collect())
.await
.map_err(|_| anyhow!("Response body timeout after {} seconds", request_timeout.as_secs()))?
.map_err(|e| anyhow!("Failed to read response body: {}", e))?
.aggregate();
// try to parse as json with serde_json
let interfaces: OpenWRTIFaces = serde_json::from_reader(body.reader())
.map_err(|e| anyhow!("Failed to parse JSON response: {}", e))?;
let wan = interfaces.first()
.ok_or_else(|| anyhow!("No WAN interface found in response"))?;
// Detect counter reset (restart scenarios)
if prev_up > wan.tx_bytes || prev_down > wan.rx_bytes {
info!("Counter reset detected, resetting baseline values");
prev_up = wan.tx_bytes;
prev_down = wan.rx_bytes;
}
let up_speed = (wan.tx_bytes - prev_up) as u64;
let down_speed = (wan.rx_bytes - prev_down) as u64;
prev_up = wan.tx_bytes;
prev_down = wan.rx_bytes;
debug!("WAN traffic: ↑ {} bytes/s, ↓ {} bytes/s", up_speed, down_speed);
let udp_server = udp_server::UdpServer::global().await;
let mut buf = [0; 16];
buf[0..8].copy_from_slice(&up_speed.to_le_bytes());
buf[8..16].copy_from_slice(&down_speed.to_le_bytes());
udp_server.publish_wan(&buf).await;
sleep(Duration::from_secs(1)).await;
}
}
}
async fn login(luci_url: &str, username: &str, password: &str) -> Result<String> {
info!("Attempting to login to: {}", luci_url);
let url = format!("{}/admin/network/iface_status/wan", luci_url).parse::<hyper::Uri>()?;
let host = url.host().expect("uri has no host");
let port = url.port_u16().unwrap_or(80);
let address = format!("{}:{}", host, port);
let connection_timeout = Duration::from_secs(10);
let request_timeout = Duration::from_secs(30);
let stream = timeout(connection_timeout, TcpStream::connect(address.clone()))
.await
.map_err(|_| anyhow!("Login connection timeout after {} seconds", connection_timeout.as_secs()))?
.map_err(|e| anyhow!("Failed to connect for login to {}: {}", address, e))?;
let io = TokioIo::new(stream);
let (mut sender, conn) = timeout(connection_timeout, hyper::client::conn::http1::handshake(io))
.await
.map_err(|_| anyhow!("Login handshake timeout after {} seconds", connection_timeout.as_secs()))?
.map_err(|e| anyhow!("Login handshake failed: {}", e))?;
let _conn_handle = tokio::task::spawn(async move {
if let Err(err) = conn.await {
error!("Login connection failed: {:?}", err);
}
});
let login_req = generate_login_req(luci_url, username, password)?;
let res = timeout(request_timeout, sender.send_request(login_req.clone()))
.await
.map_err(|_| anyhow!("Login request timeout after {} seconds", request_timeout.as_secs()))?
.map_err(|e| anyhow!("Login request failed: {}", e))?;
match res.status() {
StatusCode::OK | StatusCode::FOUND | StatusCode::SEE_OTHER => {
info!("Login successful");
}
StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED => {
error!("Login failed: Invalid credentials");
bail!("Login failed with status: {}", res.status());
}
status => {
error!("Login failed with unexpected status: {}", status);
bail!("Login failed with status: {}", status);
}
}
let cookies = res.headers().get_all(hyper::header::SET_COOKIE);
let cookies = cookies
.iter()
.filter_map(|cookie| cookie.to_str().ok())
.filter_map(|cookie| cookie.split(';').next())
.collect::<Vec<_>>();
if cookies.is_empty() {
warn!("No cookies received from login response");
}
let cookies_str = cookies.join("; ");
debug!("Received cookies: {}", cookies_str);
Ok(cookies_str)
}
fn generate_poll_wan_req(luci_url: &str, cookie: &str) -> Result<Request<Empty<Bytes>>> {
let url = format!("{}/admin/network/iface_status/wan", luci_url).parse::<hyper::Uri>()?;
let authority = url.authority().unwrap().clone();
let target = url.path_and_query().unwrap().clone();
debug!("Polling WAN traffic. Cookie: {:?}", cookie);
return Request::builder()
.uri(target)
.header(hyper::header::HOST, authority.as_str())
.header(hyper::header::CONNECTION, "Keep-Alive")
.header(hyper::header::COOKIE, cookie)
.body(Empty::<Bytes>::new())
.map_err(|e| anyhow!(e));
}
fn generate_login_req(luci_url: &str, username: &str, password: &str) -> Result<Request<String>> {
let url = format!("{}/admin/network/iface_status/wan", luci_url).parse::<hyper::Uri>()?;
let authority = url.authority().unwrap().clone();
let target = url.path_and_query().unwrap().clone();
let login_params = format!("luci_username={}&luci_password={}", username, password);
let body = login_params;
return Request::builder()
.uri(target)
.method("POST")
.header(hyper::header::HOST, authority.as_str())
.header(
hyper::header::CONTENT_TYPE,
"application/x-www-form-urlencoded",
)
.header(hyper::header::CONNECTION, "Keep-Alive")
.body(body)
.map_err(|e| anyhow!(e));
}
#[derive(Serialize, Deserialize)]
pub struct OpenWRTIFace {
rx_bytes: u64,
tx_bytes: u64,
}
impl Default for OpenWRTIFace {
fn default() -> Self {
Self {
rx_bytes: 0,
tx_bytes: 0,
}
}
}

tests/integration_test.rs (new file)

@@ -0,0 +1,148 @@
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::sleep;
use network_monitor::retry::{RetryConfig, retry_with_config};
/// Test simulating network failure recovery
#[tokio::test]
async fn test_network_failure_recovery() {
// Simulate an operation that fails a few times then succeeds
let attempt_count = Arc::new(AtomicU32::new(0));
let max_failures = 3;
let config = RetryConfig {
max_attempts: 10,
initial_delay: Duration::from_millis(10),
max_delay: Duration::from_millis(100),
backoff_multiplier: 1.5,
jitter: false, // Disable jitter for more predictable testing
};
let attempt_count_clone = attempt_count.clone();
let result = retry_with_config(config, move || {
let attempt_count = attempt_count_clone.clone();
async move {
let current_attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
if current_attempt <= max_failures {
// Simulate network error
Err(format!("Network error on attempt {}", current_attempt))
} else {
// Simulate successful recovery
Ok(format!("Success on attempt {}", current_attempt))
}
}
}).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "Success on attempt 4");
assert_eq!(attempt_count.load(Ordering::SeqCst), 4);
}
/// Test connection timeout scenario
#[tokio::test]
async fn test_connection_timeout_scenario() {
let config = RetryConfig {
max_attempts: 3,
initial_delay: Duration::from_millis(5),
max_delay: Duration::from_millis(20),
backoff_multiplier: 2.0,
jitter: false,
};
let attempt_count = Arc::new(AtomicU32::new(0));
let attempt_count_clone = attempt_count.clone();
let result: Result<String, &str> = retry_with_config(config, move || {
let attempt_count = attempt_count_clone.clone();
async move {
attempt_count.fetch_add(1, Ordering::SeqCst);
// Simulate connection timeout
sleep(Duration::from_millis(1)).await;
Err("Connection timeout")
}
}).await;
assert!(result.is_err());
assert_eq!(attempt_count.load(Ordering::SeqCst), 3); // Should have attempted 3 times
}
/// Test fast recovery scenario
#[tokio::test]
async fn test_fast_recovery() {
let config = RetryConfig::fast();
let attempt_count = Arc::new(AtomicU32::new(0));
let attempt_count_clone = attempt_count.clone();
let result = retry_with_config(config, move || {
let attempt_count = attempt_count_clone.clone();
async move {
let current_attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
if current_attempt == 1 {
Err("Temporary failure")
} else {
Ok("Quick recovery")
}
}
}).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "Quick recovery");
assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
}
/// Test slow retry scenario
#[tokio::test]
async fn test_slow_retry_scenario() {
let config = RetryConfig::slow();
let attempt_count = Arc::new(AtomicU32::new(0));
let attempt_count_clone = attempt_count.clone();
let result = retry_with_config(config, move || {
let attempt_count = attempt_count_clone.clone();
async move {
let current_attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
if current_attempt <= 2 {
Err("Service unavailable")
} else {
Ok("Service restored")
}
}
}).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "Service restored");
assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
}
/// Test maximum retry limit
#[tokio::test]
async fn test_max_retry_limit() {
let config = RetryConfig {
max_attempts: 2,
initial_delay: Duration::from_millis(1),
max_delay: Duration::from_millis(5),
backoff_multiplier: 2.0,
jitter: false,
};
let attempt_count = Arc::new(AtomicU32::new(0));
let attempt_count_clone = attempt_count.clone();
let result = retry_with_config(config, move || {
let attempt_count = attempt_count_clone.clone();
async move {
attempt_count.fetch_add(1, Ordering::SeqCst);
Err::<String, &str>("Persistent failure")
}
}).await;
assert!(result.is_err());
assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
}