From 2cc9902f235ccc65cecab35d3064113eb506d412 Mon Sep 17 00:00:00 2001 From: Mike Dilger Date: Sat, 25 May 2024 07:40:14 +1200 Subject: [PATCH] max_connections_per_ip setting --- contrib/chorus.toml | 9 ++++++++- docs/CONFIG.md | 6 ++++++ sample/sample.config.toml | 1 + src/bin/chorus.rs | 10 +++++----- src/bin/chorus_compress.rs | 2 +- src/config.rs | 5 +++++ src/globals.rs | 8 ++++++-- src/ip.rs | 2 +- src/lib.rs | 35 ++++++++++++++++++++++++++++++++--- 9 files changed, 65 insertions(+), 13 deletions(-) diff --git a/contrib/chorus.toml b/contrib/chorus.toml index ba4fffd..c233cbb 100644 --- a/contrib/chorus.toml +++ b/contrib/chorus.toml @@ -244,4 +244,11 @@ minimum_ban_seconds = 1 # # Default is 60 # -timeout_seconds = 60 \ No newline at end of file +timeout_seconds = 60 + + +# Maximum number of websocket connections per IP address +# +# Default is 2 +# +max_connections_per_ip = 2 \ No newline at end of file diff --git a/docs/CONFIG.md b/docs/CONFIG.md index dd1ac2d..43e4098 100644 --- a/docs/CONFIG.md +++ b/docs/CONFIG.md @@ -214,3 +214,9 @@ Default is 1 Number of seconds beyond which chorus times out a client that has no open subscriptions. Default is 60 + +### max_connections_per_ip + +Maximum number of websocket connections per IP address + +Default is 2 diff --git a/sample/sample.config.toml b/sample/sample.config.toml index 74b1b9e..41bfe26 100644 --- a/sample/sample.config.toml +++ b/sample/sample.config.toml @@ -26,3 +26,4 @@ client_log_level = "Warn" enable_ip_blocking = true minimum_ban_seconds = 1 timeout_seconds = 60 +max_connections_per_ip = 2 diff --git a/src/bin/chorus.rs b/src/bin/chorus.rs index 9963cbe..bf21ff5 100644 --- a/src/bin/chorus.rs +++ b/src/bin/chorus.rs @@ -136,15 +136,15 @@ async fn main() -> Result<(), Error> { let _ = GLOBALS.shutting_down.send(true); // Wait for active websockets to shutdown gracefully - let mut num_clients = GLOBALS.num_clients.load(Ordering::Relaxed); - if num_clients != 0 { - log::info!(target: "Server", "Waiting for {num_clients} websockets to shutdown..."); + let mut num_connections = GLOBALS.num_connections.load(Ordering::Relaxed); + if num_connections != 0 { + log::info!(target: "Server", "Waiting for {num_connections} websockets to shutdown..."); // We will check if all clients have shutdown every 25ms let interval = tokio::time::interval(Duration::from_millis(25)); tokio::pin!(interval); - while num_clients != 0 { + while num_connections != 0 { // If we get another shutdown signal, stop waiting for websockets tokio::select! { v = interrupt_signal.recv() => if v.is_some() { @@ -157,7 +157,7 @@ async fn main() -> Result<(), Error> { break; }, _instant = interval.tick() => { - num_clients = GLOBALS.num_clients.load(Ordering::Relaxed); + num_connections = GLOBALS.num_connections.load(Ordering::Relaxed); continue; } } diff --git a/src/bin/chorus_compress.rs b/src/bin/chorus_compress.rs index 94bdec5..b3d0556 100644 --- a/src/bin/chorus_compress.rs +++ b/src/bin/chorus_compress.rs @@ -17,7 +17,7 @@ fn main() -> Result<(), Error> { println!("Chorus must NOT be running when you do this."); println!("Proceed? (break out with ^C, or press to proceed)"); let stdin = std::io::stdin(); - let _ = stdin.lock().lines().next().unwrap().unwrap(); + let _ = stdin.lock().lines().next().unwrap().unwrap(); let store = chorus::setup_store(&config)?; let pre_stats = store.stats()?; diff --git a/src/config.rs b/src/config.rs index 2673844..ab9c2ee 100644 --- a/src/config.rs +++ b/src/config.rs @@ -33,6 +33,7 @@ pub struct FriendlyConfig { pub enable_ip_blocking: bool, pub minimum_ban_seconds: u64, pub timeout_seconds: u64, + pub max_connections_per_ip: usize, } impl Default for FriendlyConfig { @@ -64,6 +65,7 @@ impl Default for FriendlyConfig { enable_ip_blocking: true, minimum_ban_seconds: 1, timeout_seconds: 60, + max_connections_per_ip: 2, } } } @@ -97,6 +99,7 @@ impl FriendlyConfig { enable_ip_blocking, minimum_ban_seconds, timeout_seconds, + max_connections_per_ip, } = self; let mut public_key: Option = None; @@ -146,6 +149,7 @@ impl FriendlyConfig { enable_ip_blocking, minimum_ban_seconds, timeout_seconds, + max_connections_per_ip, }) } } @@ -179,6 +183,7 @@ pub struct Config { pub enable_ip_blocking: bool, pub minimum_ban_seconds: u64, pub timeout_seconds: u64, + pub max_connections_per_ip: usize, } impl Default for Config { diff --git a/src/globals.rs b/src/globals.rs index 97158da..065d763 100644 --- a/src/globals.rs +++ b/src/globals.rs @@ -1,4 +1,6 @@ use crate::config::Config; +use crate::ip::HashedIp; +use dashmap::DashMap; use hyper::server::conn::Http; use lazy_static::lazy_static; use parking_lot::RwLock; @@ -24,7 +26,8 @@ pub struct Globals { /// that subscription. pub new_events: BroadcastSender, - pub num_clients: AtomicUsize, + pub num_connections: AtomicUsize, + pub num_connections_per_ip: DashMap, pub shutting_down: WatchSender, } @@ -46,7 +49,8 @@ lazy_static! { http_server, rid: OnceLock::new(), new_events, - num_clients: AtomicUsize::new(0), + num_connections: AtomicUsize::new(0), + num_connections_per_ip: DashMap::new(), shutting_down, } }; diff --git a/src/ip.rs b/src/ip.rs index 172c24e..76b94af 100644 --- a/src/ip.rs +++ b/src/ip.rs @@ -2,7 +2,7 @@ use pocket_types::Time; use speedy::{Readable, Writable}; use std::net::{IpAddr, SocketAddr}; -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] pub struct HashedIp(pub [u8; 20], bool); impl std::fmt::Display for HashedIp { diff --git a/src/lib.rs b/src/lib.rs index b19f07b..ebbd461 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,7 @@ use crate::tls::MaybeTlsStream; use futures::{sink::SinkExt, stream::StreamExt}; use hyper::service::Service; use hyper::upgrade::Upgraded; +use hyper::StatusCode; use hyper::{Body, Request, Response}; use hyper_tungstenite::tungstenite; use hyper_tungstenite::WebSocketStream; @@ -121,6 +122,15 @@ async fn handle_http_request( None => "(no origin)".to_owned(), }; + let max_conn = GLOBALS.config.read().max_connections_per_ip; + if let Some(cur) = GLOBALS.num_connections_per_ip.get(&peer.ip()) { + if *cur.value() >= max_conn { + return Ok(Response::builder() + .status(StatusCode::TOO_MANY_REQUESTS) + .body(Body::empty())?); + } + } + if hyper_tungstenite::is_upgrade_request(&request) { let web_socket_config = WebSocketConfig { max_write_buffer_size: 1024 * 1024, // 1 MB @@ -147,8 +157,15 @@ async fn handle_http_request( replied: false, }; - // Increment count of active websockets - let old_num_websockets = GLOBALS.num_clients.fetch_add(1, Ordering::SeqCst); + // Increment connection count + let old_num_websockets = GLOBALS.num_connections.fetch_add(1, Ordering::SeqCst); + + // Increment per-ip connection count + GLOBALS + .num_connections_per_ip + .entry(peer.ip()) + .and_modify(|count| *count += 1) + .or_insert(1); // we cheat somewhat and log these websocket open and close messages // as server messages @@ -197,7 +214,19 @@ async fn handle_http_request( } // Decrement count of active websockets - let old_num_websockets = GLOBALS.num_clients.fetch_sub(1, Ordering::SeqCst); + let old_num_websockets = GLOBALS.num_connections.fetch_sub(1, Ordering::SeqCst); + + // Decrement per-ip connection count + match GLOBALS.num_connections_per_ip.get_mut(&peer.ip()) { + Some(mut refmut) => { + if *refmut.value_mut() > 0 { + *refmut.value_mut() -= 1; + } else { + unreachable!("The connection should be in the map") + } + } + None => unreachable!("The connection count should be greater than zero"), + }; // Update ip data (including ban time) // if GLOBALS.config.read().enable_ip_blocking {