From c1ee86f4df630b27b5769113c650109d754ece29 Mon Sep 17 00:00:00 2001 From: Mike Dilger Date: Sat, 13 Jul 2024 13:03:41 +1200 Subject: [PATCH] Pull the websocket thread code into a separate function --- src/lib.rs | 239 +++++++++++++++++++++++++++-------------------------- 1 file changed, 121 insertions(+), 118 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 781b461..a23f8b5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,7 +21,7 @@ use hyper::upgrade::Upgraded; use hyper::StatusCode; use hyper::{Request, Response}; use hyper_tungstenite::tungstenite; -use hyper_tungstenite::WebSocketStream; +use hyper_tungstenite::{HyperWebsocket, WebSocketStream}; use hyper_util::rt::TokioIo; use pocket_db::Store; use pocket_types::{Id, OwnedFilter, Pubkey}; @@ -161,88 +161,77 @@ async fn handle_http_request( max_frame_size: Some(1024 * 1024), // 1 MB ..Default::default() }; + let (response, websocket) = hyper_tungstenite::upgrade(&mut request, Some(web_socket_config))?; - tokio::spawn(async move { - // Await the websocket upgrade process - match websocket.await { - Ok(websocket) => { - // Build a websocket service - let mut ws_service = WebSocketService { - peer, - subscriptions: HashMap::new(), - // We start with a 1-page buffer, and grow it if needed. - buffer: vec![0; 4096], - websocket, - challenge: TextNonce::new().into_string(), - user: None, - error_punishment: 0.0, - replied: false, - }; - // Increment connection count - let old_num_websockets = GLOBALS.num_connections.fetch_add(1, Ordering::SeqCst); + // Start the websocket thread + tokio::spawn(async move { websocket_thread(peer, websocket, origin, ua).await }); - // Increment per-ip connection count - GLOBALS - .num_connections_per_ip - .entry(peer.ip()) - .and_modify(|count| *count += 1) - .or_insert(1); + Ok(response) + } else { + web::serve_http(peer, request).await + } +} - // we cheat somewhat and log these websocket open and close messages - // as server messages - log::info!( - target: "Server", - "{}: TOTAL={}, New Connection: {}, {}", - peer, - old_num_websockets + 1, - origin, - ua, - ); +async fn websocket_thread(peer: HashedPeer, websocket: HyperWebsocket, origin: String, ua: String) { + // Await the websocket upgrade process + match websocket.await { + Ok(websocket) => { + // Build a websocket service + let mut ws_service = WebSocketService { + peer, + subscriptions: HashMap::new(), + // We start with a 1-page buffer, and grow it if needed. + buffer: vec![0; 4096], + websocket, + challenge: TextNonce::new().into_string(), + user: None, + error_punishment: 0.0, + replied: false, + }; - // Everybody gets a ban on disconnect to prevent rapid reconnection - let mut session_exit: SessionExit = SessionExit::Ok; - let mut msg = "Closed"; + // Increment connection count + let old_num_websockets = GLOBALS.num_connections.fetch_add(1, Ordering::SeqCst); - // Handle the websocket - if let Err(e) = ws_service.handle_websocket_stream().await { - match e.inner { - ChorusError::Tungstenite(tungstenite::error::Error::Protocol( - tungstenite::error::ProtocolError::ResetWithoutClosingHandshake, - )) => { - // So they disconnected ungracefully. - // No big deal, still SessionExit::Ok - msg = "Reset"; - } - ChorusError::Tungstenite(tungstenite::error::Error::Io( - ref ioerror, - )) => { - match ioerror.kind() { - std::io::ErrorKind::ConnectionReset - | std::io::ErrorKind::ConnectionAborted - | std::io::ErrorKind::UnexpectedEof => { - // no biggie. - msg = "Reset"; - } - _ => { - log::error!(target: "Client", "{}: {}", peer, e); - session_exit = SessionExit::ErrorExit; - msg = "Error Exited"; - } - } - } - ChorusError::ErrorClose => { - session_exit = SessionExit::TooManyErrors; - msg = "Errored Out"; - } - ChorusError::TimedOut => { - session_exit = SessionExit::Timeout; - msg = "Timed Out (with no subscriptions)"; - } - ChorusError::Io(_) => { - // Usually "Connection reset by peer" but any I/O error - // isn't a big deal. + // Increment per-ip connection count + GLOBALS + .num_connections_per_ip + .entry(peer.ip()) + .and_modify(|count| *count += 1) + .or_insert(1); + + // we cheat somewhat and log these websocket open and close messages + // as server messages + log::info!( + target: "Server", + "{}: TOTAL={}, New Connection: {}, {}", + peer, + old_num_websockets + 1, + origin, + ua, + ); + + // Everybody gets a ban on disconnect to prevent rapid reconnection + let mut session_exit: SessionExit = SessionExit::Ok; + let mut msg = "Closed"; + + // Handle the websocket + if let Err(e) = ws_service.handle_websocket_stream().await { + match e.inner { + ChorusError::Tungstenite(tungstenite::error::Error::Protocol( + tungstenite::error::ProtocolError::ResetWithoutClosingHandshake, + )) => { + // So they disconnected ungracefully. + // No big deal, still SessionExit::Ok + msg = "Reset"; + } + ChorusError::Tungstenite(tungstenite::error::Error::Io(ref ioerror)) => { + match ioerror.kind() { + std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::ConnectionAborted + | std::io::ErrorKind::UnexpectedEof => { + // no biggie. msg = "Reset"; } _ => { @@ -252,51 +241,65 @@ async fn handle_http_request( } } } - - // Decrement count of active websockets - let old_num_websockets = GLOBALS.num_connections.fetch_sub(1, Ordering::SeqCst); - - // Decrement per-ip connection count - match GLOBALS.num_connections_per_ip.get_mut(&peer.ip()) { - Some(mut refmut) => { - if *refmut.value_mut() > 0 { - *refmut.value_mut() -= 1; - } else { - unreachable!("The connection should be in the map") - } - } - None => unreachable!("The connection count should be greater than zero"), - }; - - // Update ip data (including ban time) - // if GLOBALS.config.read().enable_ip_blocking { - let mut ban_seconds = 0; - let minimum_ban_seconds = GLOBALS.config.read().minimum_ban_seconds; - if let Ok(mut ip_data) = get_ip_data(GLOBALS.store.get().unwrap(), peer.ip()) { - ban_seconds = - ip_data.update_on_session_close(session_exit, minimum_ban_seconds); - let _ = update_ip_data(GLOBALS.store.get().unwrap(), peer.ip(), &ip_data); + ChorusError::ErrorClose => { + session_exit = SessionExit::TooManyErrors; + msg = "Errored Out"; + } + ChorusError::TimedOut => { + session_exit = SessionExit::Timeout; + msg = "Timed Out (with no subscriptions)"; + } + ChorusError::Io(_) => { + // Usually "Connection reset by peer" but any I/O error + // isn't a big deal. + msg = "Reset"; + } + _ => { + log::error!(target: "Client", "{}: {}", peer, e); + session_exit = SessionExit::ErrorExit; + msg = "Error Exited"; } - - // we cheat somewhat and log these websocket open and close messages - // as server messages - log::info!( - target: "Server", - "{}: TOTAL={}, {}, ban={}s", - peer, - old_num_websockets - 1, - msg, - ban_seconds - ); - } - Err(e) => { - log::error!(target: "Client", "{}: {}", peer, e); } } - }); - Ok(response) - } else { - web::serve_http(peer, request).await + + // Decrement connection count + let old_num_websockets = GLOBALS.num_connections.fetch_sub(1, Ordering::SeqCst); + + // Decrement per-ip connection count + match GLOBALS.num_connections_per_ip.get_mut(&peer.ip()) { + Some(mut refmut) => { + if *refmut.value_mut() > 0 { + *refmut.value_mut() -= 1; + } else { + unreachable!("The connection should be in the map") + } + } + None => unreachable!("The connection count should be greater than zero"), + }; + + // Update ip data (including ban time) + // if GLOBALS.config.read().enable_ip_blocking { + let mut ban_seconds = 0; + let minimum_ban_seconds = GLOBALS.config.read().minimum_ban_seconds; + if let Ok(mut ip_data) = get_ip_data(GLOBALS.store.get().unwrap(), peer.ip()) { + ban_seconds = ip_data.update_on_session_close(session_exit, minimum_ban_seconds); + let _ = update_ip_data(GLOBALS.store.get().unwrap(), peer.ip(), &ip_data); + } + + // we cheat somewhat and log these websocket open and close messages + // as server messages + log::info!( + target: "Server", + "{}: TOTAL={}, {}, ban={}s", + peer, + old_num_websockets - 1, + msg, + ban_seconds + ); + } + Err(e) => { + log::error!(target: "Client", "{}: {}", peer, e); + } } }