mirror of
https://github.com/mikedilger/chorus.git
synced 2026-05-03 06:51:42 +00:00
Manipulate connection counts in a safer way (include HTTP connections too)
This commit is contained in:
parent
a76be24d68
commit
7ad5fb3e15
154
src/lib.rs
154
src/lib.rs
@ -85,53 +85,82 @@ impl Service<Request<Incoming>> for ChorusService {
|
||||
// This is called for each HTTP request made by the client
|
||||
// NOTE: it is not called for each websocket message once upgraded.
|
||||
fn call(&self, req: Request<Incoming>) -> Self::Future {
|
||||
let mut hashed_peer = self.peer;
|
||||
|
||||
let failvalue =
|
||||
|c: ChorusError| -> Self::Future { Box::pin(futures::future::ready(Err(c.into()))) };
|
||||
|
||||
if GLOBALS.config.read().chorus_is_behind_a_proxy {
|
||||
// If chorus is behind a proxy that sets an "X-Real-Ip" header, we use
|
||||
// that ip address instead (otherwise their log file will just give the proxy IP
|
||||
// for every peer)
|
||||
//
|
||||
// This header must be found and be valid for us to proceed
|
||||
if let Some(rip) = req.headers().get("x-real-ip") {
|
||||
if let Ok(ripstr) = rip.to_str() {
|
||||
if let Ok(ipaddr) = ripstr.parse::<IpAddr>() {
|
||||
let hashed_ip = HashedIp::new(ipaddr);
|
||||
hashed_peer = HashedPeer::from_parts(hashed_ip, hashed_peer.port());
|
||||
} else {
|
||||
return failvalue(ChorusError::BadRealIpHeader(ripstr.to_owned()));
|
||||
}
|
||||
} else {
|
||||
return failvalue(ChorusError::BadRealIpHeaderCharacters);
|
||||
}
|
||||
} else {
|
||||
return failvalue(ChorusError::RealIpHeaderMissing);
|
||||
}
|
||||
|
||||
// Possibly IP block late (if behind a proxy)
|
||||
if GLOBALS.config.read().enable_ip_blocking {
|
||||
if let Ok(ip_data) =
|
||||
crate::get_ip_data(GLOBALS.store.get().unwrap(), hashed_peer.ip())
|
||||
{
|
||||
if ip_data.is_banned() {
|
||||
log::debug!(target: "Client",
|
||||
"{}: Blocking reconnection until {}",
|
||||
hashed_peer.ip(),
|
||||
ip_data.ban_until);
|
||||
return failvalue(ChorusError::BlockedIp);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Box::pin(async move { handle_http_request(hashed_peer, req).await })
|
||||
let hashed_peer = self.peer;
|
||||
Box::pin(async move { handle_http_request_outer(hashed_peer, req).await })
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_http_request(
|
||||
async fn handle_http_request_outer(
|
||||
mut peer: HashedPeer,
|
||||
request: Request<Incoming>,
|
||||
) -> Result<Response<Full<Bytes>>, Error> {
|
||||
if GLOBALS.config.read().chorus_is_behind_a_proxy {
|
||||
// If chorus is behind a proxy that sets an "X-Real-Ip" header, we use
|
||||
// that ip address instead (otherwise their log file will just give the proxy IP
|
||||
// for every peer)
|
||||
//
|
||||
// This header must be found and be valid for us to proceed
|
||||
if let Some(rip) = request.headers().get("x-real-ip") {
|
||||
if let Ok(ripstr) = rip.to_str() {
|
||||
if let Ok(ipaddr) = ripstr.parse::<IpAddr>() {
|
||||
let hashed_ip = HashedIp::new(ipaddr);
|
||||
peer = HashedPeer::from_parts(hashed_ip, peer.port());
|
||||
} else {
|
||||
return Err(ChorusError::BadRealIpHeader(ripstr.to_owned()).into());
|
||||
}
|
||||
} else {
|
||||
return Err(ChorusError::BadRealIpHeaderCharacters.into());
|
||||
}
|
||||
} else {
|
||||
return Err(ChorusError::RealIpHeaderMissing.into());
|
||||
}
|
||||
|
||||
// Possibly IP block late (if behind a proxy)
|
||||
if GLOBALS.config.read().enable_ip_blocking {
|
||||
if let Ok(ip_data) = crate::get_ip_data(GLOBALS.store.get().unwrap(), peer.ip()) {
|
||||
if ip_data.is_banned() {
|
||||
log::debug!(target: "Client",
|
||||
"{}: Blocking reconnection until {}",
|
||||
peer.ip(),
|
||||
ip_data.ban_until);
|
||||
return Err(ChorusError::BlockedIp.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DO NOT return anything after this point or you could screw up the counts which must
|
||||
// go both up and then down.
|
||||
|
||||
// Increment connection counts
|
||||
let _ = GLOBALS.num_connections.fetch_add(1, Ordering::SeqCst);
|
||||
GLOBALS
|
||||
.num_connections_per_ip
|
||||
.entry(peer.ip())
|
||||
.and_modify(|count| *count += 1)
|
||||
.or_insert(1);
|
||||
|
||||
// Get the response, but do not throw error at this point
|
||||
let response = handle_http_request_inner(peer, request);
|
||||
|
||||
// Decrement connection counts
|
||||
match GLOBALS.num_connections_per_ip.get_mut(&peer.ip()) {
|
||||
Some(mut refmut) => {
|
||||
if *refmut.value_mut() > 0 {
|
||||
*refmut.value_mut() -= 1;
|
||||
} else {
|
||||
unreachable!("The connection should be in the map")
|
||||
}
|
||||
}
|
||||
None => unreachable!("The connection count should be greater than zero"),
|
||||
};
|
||||
let _ = GLOBALS.num_connections.fetch_sub(1, Ordering::SeqCst);
|
||||
|
||||
// Now we return after the counts have gone up and down
|
||||
response.await
|
||||
}
|
||||
|
||||
async fn handle_http_request_inner(
|
||||
peer: HashedPeer,
|
||||
mut request: Request<Incoming>,
|
||||
) -> Result<Response<Full<Bytes>>, Error> {
|
||||
@ -145,6 +174,7 @@ async fn handle_http_request(
|
||||
None => "(no origin)".to_owned(),
|
||||
};
|
||||
|
||||
// Fail if too many requests on this IP
|
||||
let max_conn = GLOBALS.config.read().max_connections_per_ip;
|
||||
if let Some(cur) = GLOBALS.num_connections_per_ip.get(&peer.ip()) {
|
||||
if *cur.value() >= max_conn {
|
||||
@ -180,15 +210,9 @@ async fn handle_http_request(
|
||||
replied: false,
|
||||
};
|
||||
|
||||
// Increment connection count
|
||||
let old_num_websockets = GLOBALS.num_connections.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
// Increment per-ip connection count
|
||||
GLOBALS
|
||||
.num_connections_per_ip
|
||||
.entry(peer.ip())
|
||||
.and_modify(|count| *count += 1)
|
||||
.or_insert(1);
|
||||
// Everybody gets a ban on disconnect to prevent rapid reconnection
|
||||
let mut session_exit: SessionExit = SessionExit::Ok;
|
||||
let mut msg = "Closed";
|
||||
|
||||
// we cheat somewhat and log these websocket open and close messages
|
||||
// as server messages
|
||||
@ -196,15 +220,11 @@ async fn handle_http_request(
|
||||
target: "Server",
|
||||
"{}: TOTAL={}, New Connection: {}, {}",
|
||||
peer,
|
||||
old_num_websockets + 1,
|
||||
GLOBALS.num_connections.load(Ordering::Relaxed),
|
||||
origin,
|
||||
ua,
|
||||
);
|
||||
|
||||
// Everybody gets a ban on disconnect to prevent rapid reconnection
|
||||
let mut session_exit: SessionExit = SessionExit::Ok;
|
||||
let mut msg = "Closed";
|
||||
|
||||
// Handle the websocket
|
||||
if let Err(e) = ws_service.handle_websocket_stream().await {
|
||||
match e.inner {
|
||||
@ -253,21 +273,6 @@ async fn handle_http_request(
|
||||
}
|
||||
}
|
||||
|
||||
// Decrement count of active websockets
|
||||
let old_num_websockets = GLOBALS.num_connections.fetch_sub(1, Ordering::SeqCst);
|
||||
|
||||
// Decrement per-ip connection count
|
||||
match GLOBALS.num_connections_per_ip.get_mut(&peer.ip()) {
|
||||
Some(mut refmut) => {
|
||||
if *refmut.value_mut() > 0 {
|
||||
*refmut.value_mut() -= 1;
|
||||
} else {
|
||||
unreachable!("The connection should be in the map")
|
||||
}
|
||||
}
|
||||
None => unreachable!("The connection count should be greater than zero"),
|
||||
};
|
||||
|
||||
// Update ip data (including ban time)
|
||||
// if GLOBALS.config.read().enable_ip_blocking {
|
||||
let mut ban_seconds = 0;
|
||||
@ -284,7 +289,7 @@ async fn handle_http_request(
|
||||
target: "Server",
|
||||
"{}: TOTAL={}, {}, ban={}s",
|
||||
peer,
|
||||
old_num_websockets - 1,
|
||||
GLOBALS.num_connections.load(Ordering::Relaxed),
|
||||
msg,
|
||||
ban_seconds
|
||||
);
|
||||
@ -296,6 +301,7 @@ async fn handle_http_request(
|
||||
});
|
||||
Ok(response)
|
||||
} else {
|
||||
// We dont log normal HTTP requests nor do we ban them
|
||||
web::serve_http(peer, request).await
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user