diff --git a/src/error.rs b/src/error.rs index 59be098..fd05261 100644 --- a/src/error.rs +++ b/src/error.rs @@ -146,6 +146,9 @@ pub enum ChorusError { // Serde JSON SerdeJson(serde_json::Error), + // Shutting Down + ShuttingDown, + // Signal - Not Blossom Request SignalNotBlossom, @@ -222,6 +225,7 @@ impl std::fmt::Display for ChorusError { ChorusError::Rustls(e) => write!(f, "{e}"), ChorusError::Scraper => write!(f, "Filter is underspecified. Scrapers are not allowed"), ChorusError::SerdeJson(e) => write!(f, "{e}"), + ChorusError::ShuttingDown => write!(f, "Shutting down"), ChorusError::SignalNotBlossom => write!(f, "internal-signal-not-blossom"), ChorusError::Speedy(e) => write!(f, "{e}"), ChorusError::TimedOut => write!(f, "Timed out"), @@ -318,6 +322,7 @@ impl ChorusError { ChorusError::Rustls(_) => 0.0, ChorusError::Scraper => 0.4, ChorusError::SerdeJson(_) => 0.0, + ChorusError::ShuttingDown => 0.0, ChorusError::SignalNotBlossom => 0.0, ChorusError::Speedy(_) => 0.0, ChorusError::TimedOut => 0.1, diff --git a/src/lib.rs b/src/lib.rs index e3bb337..4782333 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,6 +28,7 @@ use hyper_util::rt::TokioIo; use pocket_db::Store; use pocket_types::{Id, OwnedFilter, Pubkey}; use speedy::{Readable, Writable}; +use std::borrow::Cow; use std::collections::HashMap; use std::error::Error as StdError; use std::fs::OpenOptions; @@ -377,6 +378,30 @@ impl WebSocketService { Ok(self.websocket.send(m).await?) } + async fn wsclose(&mut self, error: Error) -> Result<(), Error> { + use tungstenite::protocol::frame::coding::CloseCode; + use tungstenite::protocol::frame::CloseFrame; + + let (code, reason) = match &error.inner { + ChorusError::TimedOut => (CloseCode::Policy, Cow::Borrowed("timed out")), + ChorusError::ShuttingDown => (CloseCode::Restart, Cow::Borrowed("restarting")), + ChorusError::BannedUser | ChorusError::BlockedIp => { + (CloseCode::Policy, Cow::Borrowed("banned")) + } + e => (CloseCode::Error, Cow::Owned(format!("{}", e))), + }; + + let close_frame = CloseFrame { code, reason }; + + // NOTE: This is the same as sending Message::Close(..) + self.websocket.close(Some(close_frame)).await?; + + // Drive to completion + while let Some(_) = self.websocket.next().await {} + + Err(error) + } + async fn handle_websocket_stream(&mut self) -> Result<(), Error> { // Subscribe to the shutting down channel let mut shutting_down = GLOBALS.shutting_down.subscribe(); @@ -403,8 +428,7 @@ impl WebSocketService { if self.subscriptions.is_empty() { // And they are idle for timeout_seconds with no subscriptions if last_message_at + Duration::from_secs(timeout_seconds) < instant { - self.send(Message::Close(None)).await?; - return Err(ChorusError::TimedOut.into()); + self.wsclose(ChorusError::TimedOut.into()).await?; } } } @@ -414,10 +438,7 @@ impl WebSocketService { Some(message) => { let message = message?; if let Err(e) = self.handle_websocket_message(message).await { - if let Err(e) = self.websocket.close(None).await { - log::info!(target: "Client", "{}: Err on websocket close: {e}", self.peer); - } - return Err(e); + self.wsclose(e).await?; } }, None => break, // the websocket is closed @@ -428,9 +449,7 @@ impl WebSocketService { self.handle_new_event(offset).await?; }, _r = shutting_down.changed() => { - // Shutdown the websocket gracefully - self.send(Message::Close(None)).await?; - break; + self.wsclose(ChorusError::ShuttingDown.into()).await?; }, } }