max_connections_per_ip setting

This commit is contained in:
Mike Dilger 2024-05-25 07:40:14 +12:00
parent 0d2d562bb0
commit 2cc9902f23
9 changed files with 65 additions and 13 deletions

View File

@ -244,4 +244,11 @@ minimum_ban_seconds = 1
#
# Default is 60
#
timeout_seconds = 60
timeout_seconds = 60
# Maximum number of websocket connections per IP address
#
# Default is 2
#
max_connections_per_ip = 2

View File

@ -214,3 +214,9 @@ Default is 1
Number of seconds beyond which chorus times out a client that has no open subscriptions.
Default is 60
### max_connections_per_ip
Maximum number of websocket connections per IP address
Default is 2

View File

@ -26,3 +26,4 @@ client_log_level = "Warn"
enable_ip_blocking = true
minimum_ban_seconds = 1
timeout_seconds = 60
max_connections_per_ip = 2

View File

@ -136,15 +136,15 @@ async fn main() -> Result<(), Error> {
let _ = GLOBALS.shutting_down.send(true);
// Wait for active websockets to shutdown gracefully
let mut num_clients = GLOBALS.num_clients.load(Ordering::Relaxed);
if num_clients != 0 {
log::info!(target: "Server", "Waiting for {num_clients} websockets to shutdown...");
let mut num_connections = GLOBALS.num_connections.load(Ordering::Relaxed);
if num_connections != 0 {
log::info!(target: "Server", "Waiting for {num_connections} websockets to shutdown...");
// We will check if all clients have shutdown every 25ms
let interval = tokio::time::interval(Duration::from_millis(25));
tokio::pin!(interval);
while num_clients != 0 {
while num_connections != 0 {
// If we get another shutdown signal, stop waiting for websockets
tokio::select! {
v = interrupt_signal.recv() => if v.is_some() {
@ -157,7 +157,7 @@ async fn main() -> Result<(), Error> {
break;
},
_instant = interval.tick() => {
num_clients = GLOBALS.num_clients.load(Ordering::Relaxed);
num_connections = GLOBALS.num_connections.load(Ordering::Relaxed);
continue;
}
}

View File

@ -17,7 +17,7 @@ fn main() -> Result<(), Error> {
println!("Chorus must NOT be running when you do this.");
println!("Proceed? (break out with ^C, or press <ENTER> to proceed)");
let stdin = std::io::stdin();
let _ = stdin.lock().lines().next().unwrap().unwrap();
let _ = stdin.lock().lines().next().unwrap().unwrap();
let store = chorus::setup_store(&config)?;
let pre_stats = store.stats()?;

View File

@ -33,6 +33,7 @@ pub struct FriendlyConfig {
pub enable_ip_blocking: bool,
pub minimum_ban_seconds: u64,
pub timeout_seconds: u64,
pub max_connections_per_ip: usize,
}
impl Default for FriendlyConfig {
@ -64,6 +65,7 @@ impl Default for FriendlyConfig {
enable_ip_blocking: true,
minimum_ban_seconds: 1,
timeout_seconds: 60,
max_connections_per_ip: 2,
}
}
}
@ -97,6 +99,7 @@ impl FriendlyConfig {
enable_ip_blocking,
minimum_ban_seconds,
timeout_seconds,
max_connections_per_ip,
} = self;
let mut public_key: Option<Pubkey> = None;
@ -146,6 +149,7 @@ impl FriendlyConfig {
enable_ip_blocking,
minimum_ban_seconds,
timeout_seconds,
max_connections_per_ip,
})
}
}
@ -179,6 +183,7 @@ pub struct Config {
pub enable_ip_blocking: bool,
pub minimum_ban_seconds: u64,
pub timeout_seconds: u64,
pub max_connections_per_ip: usize,
}
impl Default for Config {

View File

@ -1,4 +1,6 @@
use crate::config::Config;
use crate::ip::HashedIp;
use dashmap::DashMap;
use hyper::server::conn::Http;
use lazy_static::lazy_static;
use parking_lot::RwLock;
@ -24,7 +26,8 @@ pub struct Globals {
/// that subscription.
pub new_events: BroadcastSender<u64>,
pub num_clients: AtomicUsize,
pub num_connections: AtomicUsize,
pub num_connections_per_ip: DashMap<HashedIp, usize>,
pub shutting_down: WatchSender<bool>,
}
@ -46,7 +49,8 @@ lazy_static! {
http_server,
rid: OnceLock::new(),
new_events,
num_clients: AtomicUsize::new(0),
num_connections: AtomicUsize::new(0),
num_connections_per_ip: DashMap::new(),
shutting_down,
}
};

View File

@ -2,7 +2,7 @@ use pocket_types::Time;
use speedy::{Readable, Writable};
use std::net::{IpAddr, SocketAddr};
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct HashedIp(pub [u8; 20], bool);
impl std::fmt::Display for HashedIp {

View File

@ -16,6 +16,7 @@ use crate::tls::MaybeTlsStream;
use futures::{sink::SinkExt, stream::StreamExt};
use hyper::service::Service;
use hyper::upgrade::Upgraded;
use hyper::StatusCode;
use hyper::{Body, Request, Response};
use hyper_tungstenite::tungstenite;
use hyper_tungstenite::WebSocketStream;
@ -121,6 +122,15 @@ async fn handle_http_request(
None => "(no origin)".to_owned(),
};
let max_conn = GLOBALS.config.read().max_connections_per_ip;
if let Some(cur) = GLOBALS.num_connections_per_ip.get(&peer.ip()) {
if *cur.value() >= max_conn {
return Ok(Response::builder()
.status(StatusCode::TOO_MANY_REQUESTS)
.body(Body::empty())?);
}
}
if hyper_tungstenite::is_upgrade_request(&request) {
let web_socket_config = WebSocketConfig {
max_write_buffer_size: 1024 * 1024, // 1 MB
@ -147,8 +157,15 @@ async fn handle_http_request(
replied: false,
};
// Increment count of active websockets
let old_num_websockets = GLOBALS.num_clients.fetch_add(1, Ordering::SeqCst);
// Increment connection count
let old_num_websockets = GLOBALS.num_connections.fetch_add(1, Ordering::SeqCst);
// Increment per-ip connection count
GLOBALS
.num_connections_per_ip
.entry(peer.ip())
.and_modify(|count| *count += 1)
.or_insert(1);
// we cheat somewhat and log these websocket open and close messages
// as server messages
@ -197,7 +214,19 @@ async fn handle_http_request(
}
// Decrement count of active websockets
let old_num_websockets = GLOBALS.num_clients.fetch_sub(1, Ordering::SeqCst);
let old_num_websockets = GLOBALS.num_connections.fetch_sub(1, Ordering::SeqCst);
// Decrement per-ip connection count
match GLOBALS.num_connections_per_ip.get_mut(&peer.ip()) {
Some(mut refmut) => {
if *refmut.value_mut() > 0 {
*refmut.value_mut() -= 1;
} else {
unreachable!("The connection should be in the map")
}
}
None => unreachable!("The connection count should be greater than zero"),
};
// Update ip data (including ban time)
// if GLOBALS.config.read().enable_ip_blocking {