diff --git a/src/lib.rs b/src/lib.rs index 781b461..cb4a364 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -85,53 +85,82 @@ impl Service> for ChorusService { // This is called for each HTTP request made by the client // NOTE: it is not called for each websocket message once upgraded. fn call(&self, req: Request) -> Self::Future { - let mut hashed_peer = self.peer; - - let failvalue = - |c: ChorusError| -> Self::Future { Box::pin(futures::future::ready(Err(c.into()))) }; - - if GLOBALS.config.read().chorus_is_behind_a_proxy { - // If chorus is behind a proxy that sets an "X-Real-Ip" header, we use - // that ip address instead (otherwise their log file will just give the proxy IP - // for every peer) - // - // This header must be found and be valid for us to proceed - if let Some(rip) = req.headers().get("x-real-ip") { - if let Ok(ripstr) = rip.to_str() { - if let Ok(ipaddr) = ripstr.parse::() { - let hashed_ip = HashedIp::new(ipaddr); - hashed_peer = HashedPeer::from_parts(hashed_ip, hashed_peer.port()); - } else { - return failvalue(ChorusError::BadRealIpHeader(ripstr.to_owned())); - } - } else { - return failvalue(ChorusError::BadRealIpHeaderCharacters); - } - } else { - return failvalue(ChorusError::RealIpHeaderMissing); - } - - // Possibly IP block late (if behind a proxy) - if GLOBALS.config.read().enable_ip_blocking { - if let Ok(ip_data) = - crate::get_ip_data(GLOBALS.store.get().unwrap(), hashed_peer.ip()) - { - if ip_data.is_banned() { - log::debug!(target: "Client", - "{}: Blocking reconnection until {}", - hashed_peer.ip(), - ip_data.ban_until); - return failvalue(ChorusError::BlockedIp); - } - } - } - } - - Box::pin(async move { handle_http_request(hashed_peer, req).await }) + let hashed_peer = self.peer; + Box::pin(async move { handle_http_request_outer(hashed_peer, req).await }) } } -async fn handle_http_request( +async fn handle_http_request_outer( + mut peer: HashedPeer, + request: Request, +) -> Result>, Error> { + if GLOBALS.config.read().chorus_is_behind_a_proxy { + // If chorus is behind a proxy that sets an "X-Real-Ip" header, we use + // that ip address instead (otherwise their log file will just give the proxy IP + // for every peer) + // + // This header must be found and be valid for us to proceed + if let Some(rip) = request.headers().get("x-real-ip") { + if let Ok(ripstr) = rip.to_str() { + if let Ok(ipaddr) = ripstr.parse::() { + let hashed_ip = HashedIp::new(ipaddr); + peer = HashedPeer::from_parts(hashed_ip, peer.port()); + } else { + return Err(ChorusError::BadRealIpHeader(ripstr.to_owned()).into()); + } + } else { + return Err(ChorusError::BadRealIpHeaderCharacters.into()); + } + } else { + return Err(ChorusError::RealIpHeaderMissing.into()); + } + + // Possibly IP block late (if behind a proxy) + if GLOBALS.config.read().enable_ip_blocking { + if let Ok(ip_data) = crate::get_ip_data(GLOBALS.store.get().unwrap(), peer.ip()) { + if ip_data.is_banned() { + log::debug!(target: "Client", + "{}: Blocking reconnection until {}", + peer.ip(), + ip_data.ban_until); + return Err(ChorusError::BlockedIp.into()); + } + } + } + } + + // DO NOT return anything after this point or you could screw up the counts which must + // go both up and then down. + + // Increment connection counts + let _ = GLOBALS.num_connections.fetch_add(1, Ordering::SeqCst); + GLOBALS + .num_connections_per_ip + .entry(peer.ip()) + .and_modify(|count| *count += 1) + .or_insert(1); + + // Get the response, but do not throw error at this point + let response = handle_http_request_inner(peer, request); + + // Decrement connection counts + match GLOBALS.num_connections_per_ip.get_mut(&peer.ip()) { + Some(mut refmut) => { + if *refmut.value_mut() > 0 { + *refmut.value_mut() -= 1; + } else { + unreachable!("The connection should be in the map") + } + } + None => unreachable!("The connection count should be greater than zero"), + }; + let _ = GLOBALS.num_connections.fetch_sub(1, Ordering::SeqCst); + + // Now we return after the counts have gone up and down + response.await +} + +async fn handle_http_request_inner( peer: HashedPeer, mut request: Request, ) -> Result>, Error> { @@ -145,6 +174,7 @@ async fn handle_http_request( None => "(no origin)".to_owned(), }; + // Fail if too many requests on this IP let max_conn = GLOBALS.config.read().max_connections_per_ip; if let Some(cur) = GLOBALS.num_connections_per_ip.get(&peer.ip()) { if *cur.value() >= max_conn { @@ -180,15 +210,9 @@ async fn handle_http_request( replied: false, }; - // Increment connection count - let old_num_websockets = GLOBALS.num_connections.fetch_add(1, Ordering::SeqCst); - - // Increment per-ip connection count - GLOBALS - .num_connections_per_ip - .entry(peer.ip()) - .and_modify(|count| *count += 1) - .or_insert(1); + // Everybody gets a ban on disconnect to prevent rapid reconnection + let mut session_exit: SessionExit = SessionExit::Ok; + let mut msg = "Closed"; // we cheat somewhat and log these websocket open and close messages // as server messages @@ -196,15 +220,11 @@ async fn handle_http_request( target: "Server", "{}: TOTAL={}, New Connection: {}, {}", peer, - old_num_websockets + 1, + GLOBALS.num_connections.load(Ordering::Relaxed), origin, ua, ); - // Everybody gets a ban on disconnect to prevent rapid reconnection - let mut session_exit: SessionExit = SessionExit::Ok; - let mut msg = "Closed"; - // Handle the websocket if let Err(e) = ws_service.handle_websocket_stream().await { match e.inner { @@ -253,21 +273,6 @@ async fn handle_http_request( } } - // Decrement count of active websockets - let old_num_websockets = GLOBALS.num_connections.fetch_sub(1, Ordering::SeqCst); - - // Decrement per-ip connection count - match GLOBALS.num_connections_per_ip.get_mut(&peer.ip()) { - Some(mut refmut) => { - if *refmut.value_mut() > 0 { - *refmut.value_mut() -= 1; - } else { - unreachable!("The connection should be in the map") - } - } - None => unreachable!("The connection count should be greater than zero"), - }; - // Update ip data (including ban time) // if GLOBALS.config.read().enable_ip_blocking { let mut ban_seconds = 0; @@ -284,7 +289,7 @@ async fn handle_http_request( target: "Server", "{}: TOTAL={}, {}, ban={}s", peer, - old_num_websockets - 1, + GLOBALS.num_connections.load(Ordering::Relaxed), msg, ban_seconds ); @@ -296,6 +301,7 @@ async fn handle_http_request( }); Ok(response) } else { + // We dont log normal HTTP requests nor do we ban them web::serve_http(peer, request).await } }