Enforce throtting on output

This commit is contained in:
Mike Dilger 2024-07-14 12:34:18 +12:00
parent 7a1b8e7ec5
commit 1f1a74e442

View File

@ -326,9 +326,14 @@ impl WebSocketService {
async fn send(&mut self, m: Message) -> Result<(), Error> { async fn send(&mut self, m: Message) -> Result<(), Error> {
// Throttling: we consume burst tokens, but we do not throttle on output // Throttling: we consume burst tokens, but we do not throttle on output
if m.len() > self.burst_tokens { if m.len() > self.burst_tokens {
self.burst_tokens = 0; log::info!(target: "Client", "{}: Rate limited exceeded", self.peer);
let reply = NostrReply::Notice("Rate limit exceeded.".into());
self.websocket.send(Message::text(reply.as_json())).await?;
let error = ChorusError::RateLimitExceeded;
self.error_punishment += error.punishment();
return Err(error.into());
} else { } else {
self.burst_tokens = self.burst_tokens - m.len(); self.burst_tokens -= m.len();
} }
self.replied = true; self.replied = true;
@ -443,7 +448,7 @@ impl WebSocketService {
// Grant new tokens // Grant new tokens
let new_tokens = throttling_bytes_per_second * elapsed.as_millis() as usize / 1_000; let new_tokens = throttling_bytes_per_second * elapsed.as_millis() as usize / 1_000;
self.burst_tokens = self.burst_tokens + new_tokens; self.burst_tokens += new_tokens;
// Cap tokens to a maximum // Cap tokens to a maximum
if self.burst_tokens > throttling_burst { if self.burst_tokens > throttling_burst {
@ -454,12 +459,12 @@ impl WebSocketService {
if message.len() > self.burst_tokens { if message.len() > self.burst_tokens {
log::info!(target: "Client", "{}: Rate limited exceeded", self.peer); log::info!(target: "Client", "{}: Rate limited exceeded", self.peer);
let reply = NostrReply::Notice("Rate limit exceeded.".into()); let reply = NostrReply::Notice("Rate limit exceeded.".into());
self.send(Message::text(reply.as_json())).await?; self.websocket.send(Message::text(reply.as_json())).await?;
let error = ChorusError::RateLimitExceeded; let error = ChorusError::RateLimitExceeded;
self.error_punishment += error.punishment(); self.error_punishment += error.punishment();
return Err(error.into()); return Err(error.into());
} else { } else {
self.burst_tokens = self.burst_tokens - message.len(); self.burst_tokens -= message.len();
} }
} }