Pull the websocket thread code into a separate function

This commit is contained in:
Mike Dilger 2024-07-13 13:03:41 +12:00
parent ad7e70e3a0
commit c1ee86f4df

View File

@ -21,7 +21,7 @@ use hyper::upgrade::Upgraded;
use hyper::StatusCode; use hyper::StatusCode;
use hyper::{Request, Response}; use hyper::{Request, Response};
use hyper_tungstenite::tungstenite; use hyper_tungstenite::tungstenite;
use hyper_tungstenite::WebSocketStream; use hyper_tungstenite::{HyperWebsocket, WebSocketStream};
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use pocket_db::Store; use pocket_db::Store;
use pocket_types::{Id, OwnedFilter, Pubkey}; use pocket_types::{Id, OwnedFilter, Pubkey};
@ -161,88 +161,77 @@ async fn handle_http_request(
max_frame_size: Some(1024 * 1024), // 1 MB max_frame_size: Some(1024 * 1024), // 1 MB
..Default::default() ..Default::default()
}; };
let (response, websocket) = let (response, websocket) =
hyper_tungstenite::upgrade(&mut request, Some(web_socket_config))?; hyper_tungstenite::upgrade(&mut request, Some(web_socket_config))?;
tokio::spawn(async move {
// Await the websocket upgrade process
match websocket.await {
Ok(websocket) => {
// Build a websocket service
let mut ws_service = WebSocketService {
peer,
subscriptions: HashMap::new(),
// We start with a 1-page buffer, and grow it if needed.
buffer: vec![0; 4096],
websocket,
challenge: TextNonce::new().into_string(),
user: None,
error_punishment: 0.0,
replied: false,
};
// Increment connection count // Start the websocket thread
let old_num_websockets = GLOBALS.num_connections.fetch_add(1, Ordering::SeqCst); tokio::spawn(async move { websocket_thread(peer, websocket, origin, ua).await });
// Increment per-ip connection count Ok(response)
GLOBALS } else {
.num_connections_per_ip web::serve_http(peer, request).await
.entry(peer.ip()) }
.and_modify(|count| *count += 1) }
.or_insert(1);
// we cheat somewhat and log these websocket open and close messages async fn websocket_thread(peer: HashedPeer, websocket: HyperWebsocket, origin: String, ua: String) {
// as server messages // Await the websocket upgrade process
log::info!( match websocket.await {
target: "Server", Ok(websocket) => {
"{}: TOTAL={}, New Connection: {}, {}", // Build a websocket service
peer, let mut ws_service = WebSocketService {
old_num_websockets + 1, peer,
origin, subscriptions: HashMap::new(),
ua, // We start with a 1-page buffer, and grow it if needed.
); buffer: vec![0; 4096],
websocket,
challenge: TextNonce::new().into_string(),
user: None,
error_punishment: 0.0,
replied: false,
};
// Everybody gets a ban on disconnect to prevent rapid reconnection // Increment connection count
let mut session_exit: SessionExit = SessionExit::Ok; let old_num_websockets = GLOBALS.num_connections.fetch_add(1, Ordering::SeqCst);
let mut msg = "Closed";
// Handle the websocket // Increment per-ip connection count
if let Err(e) = ws_service.handle_websocket_stream().await { GLOBALS
match e.inner { .num_connections_per_ip
ChorusError::Tungstenite(tungstenite::error::Error::Protocol( .entry(peer.ip())
tungstenite::error::ProtocolError::ResetWithoutClosingHandshake, .and_modify(|count| *count += 1)
)) => { .or_insert(1);
// So they disconnected ungracefully.
// No big deal, still SessionExit::Ok // we cheat somewhat and log these websocket open and close messages
msg = "Reset"; // as server messages
} log::info!(
ChorusError::Tungstenite(tungstenite::error::Error::Io( target: "Server",
ref ioerror, "{}: TOTAL={}, New Connection: {}, {}",
)) => { peer,
match ioerror.kind() { old_num_websockets + 1,
std::io::ErrorKind::ConnectionReset origin,
| std::io::ErrorKind::ConnectionAborted ua,
| std::io::ErrorKind::UnexpectedEof => { );
// no biggie.
msg = "Reset"; // Everybody gets a ban on disconnect to prevent rapid reconnection
} let mut session_exit: SessionExit = SessionExit::Ok;
_ => { let mut msg = "Closed";
log::error!(target: "Client", "{}: {}", peer, e);
session_exit = SessionExit::ErrorExit; // Handle the websocket
msg = "Error Exited"; if let Err(e) = ws_service.handle_websocket_stream().await {
} match e.inner {
} ChorusError::Tungstenite(tungstenite::error::Error::Protocol(
} tungstenite::error::ProtocolError::ResetWithoutClosingHandshake,
ChorusError::ErrorClose => { )) => {
session_exit = SessionExit::TooManyErrors; // So they disconnected ungracefully.
msg = "Errored Out"; // No big deal, still SessionExit::Ok
} msg = "Reset";
ChorusError::TimedOut => { }
session_exit = SessionExit::Timeout; ChorusError::Tungstenite(tungstenite::error::Error::Io(ref ioerror)) => {
msg = "Timed Out (with no subscriptions)"; match ioerror.kind() {
} std::io::ErrorKind::ConnectionReset
ChorusError::Io(_) => { | std::io::ErrorKind::ConnectionAborted
// Usually "Connection reset by peer" but any I/O error | std::io::ErrorKind::UnexpectedEof => {
// isn't a big deal. // no biggie.
msg = "Reset"; msg = "Reset";
} }
_ => { _ => {
@ -252,51 +241,65 @@ async fn handle_http_request(
} }
} }
} }
ChorusError::ErrorClose => {
// Decrement count of active websockets session_exit = SessionExit::TooManyErrors;
let old_num_websockets = GLOBALS.num_connections.fetch_sub(1, Ordering::SeqCst); msg = "Errored Out";
}
// Decrement per-ip connection count ChorusError::TimedOut => {
match GLOBALS.num_connections_per_ip.get_mut(&peer.ip()) { session_exit = SessionExit::Timeout;
Some(mut refmut) => { msg = "Timed Out (with no subscriptions)";
if *refmut.value_mut() > 0 { }
*refmut.value_mut() -= 1; ChorusError::Io(_) => {
} else { // Usually "Connection reset by peer" but any I/O error
unreachable!("The connection should be in the map") // isn't a big deal.
} msg = "Reset";
} }
None => unreachable!("The connection count should be greater than zero"), _ => {
}; log::error!(target: "Client", "{}: {}", peer, e);
session_exit = SessionExit::ErrorExit;
// Update ip data (including ban time) msg = "Error Exited";
// if GLOBALS.config.read().enable_ip_blocking {
let mut ban_seconds = 0;
let minimum_ban_seconds = GLOBALS.config.read().minimum_ban_seconds;
if let Ok(mut ip_data) = get_ip_data(GLOBALS.store.get().unwrap(), peer.ip()) {
ban_seconds =
ip_data.update_on_session_close(session_exit, minimum_ban_seconds);
let _ = update_ip_data(GLOBALS.store.get().unwrap(), peer.ip(), &ip_data);
} }
// we cheat somewhat and log these websocket open and close messages
// as server messages
log::info!(
target: "Server",
"{}: TOTAL={}, {}, ban={}s",
peer,
old_num_websockets - 1,
msg,
ban_seconds
);
}
Err(e) => {
log::error!(target: "Client", "{}: {}", peer, e);
} }
} }
});
Ok(response) // Decrement connection count
} else { let old_num_websockets = GLOBALS.num_connections.fetch_sub(1, Ordering::SeqCst);
web::serve_http(peer, request).await
// Decrement per-ip connection count
match GLOBALS.num_connections_per_ip.get_mut(&peer.ip()) {
Some(mut refmut) => {
if *refmut.value_mut() > 0 {
*refmut.value_mut() -= 1;
} else {
unreachable!("The connection should be in the map")
}
}
None => unreachable!("The connection count should be greater than zero"),
};
// Update ip data (including ban time)
// if GLOBALS.config.read().enable_ip_blocking {
let mut ban_seconds = 0;
let minimum_ban_seconds = GLOBALS.config.read().minimum_ban_seconds;
if let Ok(mut ip_data) = get_ip_data(GLOBALS.store.get().unwrap(), peer.ip()) {
ban_seconds = ip_data.update_on_session_close(session_exit, minimum_ban_seconds);
let _ = update_ip_data(GLOBALS.store.get().unwrap(), peer.ip(), &ip_data);
}
// we cheat somewhat and log these websocket open and close messages
// as server messages
log::info!(
target: "Server",
"{}: TOTAL={}, {}, ban={}s",
peer,
old_num_websockets - 1,
msg,
ban_seconds
);
}
Err(e) => {
log::error!(target: "Client", "{}: {}", peer, e);
}
} }
} }