diff --git a/src/main.rs b/src/main.rs index 92e493d..5362f57 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,14 @@ -use std::{format, time::Duration}; +use std::time::Duration; +use anyhow::{anyhow, bail}; use clap::Parser; -use futures_util::{pin_mut, StreamExt}; -use log::{error, info}; -use tokio::{ - io::{stdout, AsyncWriteExt}, - time::sleep, +use futures_util::StreamExt; +use log::{debug, error}; +use tokio::time::sleep; +use tokio_tungstenite::{ + connect_async, + tungstenite::{protocol::{frame::coding::CloseCode, CloseFrame}, Message}, }; -use tokio_tungstenite::connect_async; mod clash_conn_msg; mod statistics; @@ -80,66 +81,64 @@ async fn main() { async fn pipe(connect_addr: String) -> anyhow::Result<()> { let url = url::Url::parse(&connect_addr).map_err(|err| anyhow::anyhow!(err))?; - let (ws_stream, _) = connect_async(url) + let (mut ws_stream, _) = connect_async(url) .await .map_err(|err| anyhow::anyhow!(err))?; println!("WebSocket handshake has been successfully completed"); - let (_, read) = ws_stream.split(); + while let Some(message) = ws_stream.next().await { + let message = message?; - let ws_to_stdout = { - read.for_each(|message| async { - if let Err(err) = message { - error!("bad message. {}", err); - return; + match message { + Message::Text(text) => { + let data = text.into_bytes(); + + let wrapper = + serde_json::from_slice::(&data); + + if let Err(err) = wrapper { + error!("parse message failed. {}", err); + ws_stream + .close(Some(CloseFrame { + code: CloseCode::Unsupported, + reason: "parse message failed".into(), + })).await + .map_err(|err| anyhow!(err))?; + bail!(err); + } + let wrapper = wrapper.unwrap(); + + let statistics_manager = statistics::StatisticsManager::global().await; + + statistics_manager.update(&wrapper).await; + + let state = statistics_manager.get_state().await; + + debug!("len: {},speed_upload: {}, speed_download: {}, update_direct_speed: {}, download_direct_speed: {}, update_proxy_speed: {}, download_proxy_speed: {}", + state.connections, + state.speed_upload, + state.speed_download, + state.direct_upload_speed, + state.direct_download_speed, + state.proxy_upload_speed, + state.proxy_download_speed, + ); + + let udp_server = udp_server::UdpServer::global().await; + let mut buf = [0; 32]; + buf[0..8].copy_from_slice(&state.direct_upload_speed.to_le_bytes()); + buf[8..16].copy_from_slice(&state.direct_download_speed.to_le_bytes()); + buf[16..24].copy_from_slice(&state.proxy_upload_speed.to_le_bytes()); + buf[24..32].copy_from_slice(&state.proxy_download_speed.to_le_bytes()); + udp_server.publish_clash(&buf).await; } - - let message = message.unwrap(); - let data = message.into_data(); - - let wrapper = serde_json::from_slice::(&data); - - if let Err(err) = wrapper { - error!("parse message failed. {}", err); - return; + Message::Close(_) => { + println!("Server requested close"); + break; } - let wrapper = wrapper.unwrap(); - - let statistics_manager = statistics::StatisticsManager::global().await; - - statistics_manager.update(&wrapper).await; - - let state = statistics_manager.get_state().await; - - // stdout() - // .write_all( - // format!( - // "len: {},speed_upload: {}, speed_download: {}, update_direct_speed: {}, download_direct_speed: {}, update_proxy_speed: {}, download_proxy_speed: {}\n", - // state.connections, - // state.speed_upload, - // state.speed_download, - // state.direct_upload_speed, - // state.direct_download_speed, - // state.proxy_upload_speed, - // state.proxy_download_speed, - // ) - // .as_bytes(), - // ) - // .await - // .unwrap(); - - let udp_server = udp_server::UdpServer::global().await; - let mut buf = [0; 32]; - buf[0..8].copy_from_slice(&state.direct_upload_speed.to_le_bytes()); - buf[8..16].copy_from_slice(&state.direct_download_speed.to_le_bytes()); - buf[16..24].copy_from_slice(&state.proxy_upload_speed.to_le_bytes()); - buf[24..32].copy_from_slice(&state.proxy_download_speed.to_le_bytes()); - udp_server.publish_clash(&buf).await; - }) - }; - - pin_mut!(ws_to_stdout); - ws_to_stdout.await; + _ => {} + } + } Ok(()) }