From f00f6cccf4ea32d7bdd6c42e50d80a989cd48ec5 Mon Sep 17 00:00:00 2001 From: Mike Dilger Date: Sat, 22 Jun 2024 12:04:23 +1200 Subject: [PATCH] retire MaybeTlsStream, use a dyn trait instead --- src/bin/chorus.rs | 42 ++++++++++++--------- src/lib.rs | 9 ++++- src/tls.rs | 94 ----------------------------------------------- 3 files changed, 32 insertions(+), 113 deletions(-) diff --git a/src/bin/chorus.rs b/src/bin/chorus.rs index bf21ff5..52e2040 100644 --- a/src/bin/chorus.rs +++ b/src/bin/chorus.rs @@ -2,7 +2,7 @@ use chorus::config::{Config, FriendlyConfig}; use chorus::error::Error; use chorus::globals::GLOBALS; use chorus::ip::HashedPeer; -use chorus::tls::MaybeTlsStream; +use chorus::FullStream; use std::env; use std::fs::OpenOptions; use std::io::Read; @@ -104,27 +104,35 @@ async fn main() -> Result<(), Error> { } } - if let Some(tls_acceptor) = &maybe_tls_acceptor { - let tls_acceptor_clone = tls_acceptor.clone(); - tokio::spawn(async move { - match tls_acceptor_clone.accept(tcp_stream).await { - Err(e) => log::error!( - target: "Client", - "{}: {}", hashed_peer, e - ), - Ok(tls_stream) => { - if let Err(e) = chorus::serve(MaybeTlsStream::Rustls(tls_stream), hashed_peer).await { + let maybe_tls_acceptor_clone = maybe_tls_acceptor.clone(); + tokio::spawn(async move { + let stream: Box = match maybe_tls_acceptor_clone { + Some(tls_acceptor) => { + match tls_acceptor.accept(tcp_stream).await { + Ok(stream) => Box::new(stream), + Err(e) => { log::error!( target: "Client", - "{}: {}", hashed_peer, e + "{}: TLS accept: {}", hashed_peer, e ); + return; } } - } - }); - } else { - chorus::serve(MaybeTlsStream::Plain(tcp_stream), hashed_peer).await?; - } + }, + None => Box::new(tcp_stream) + }; + if let Err(e) = chorus::serve(stream, hashed_peer).await { + log::error!( + target: "Client", + "{}: {}", hashed_peer, e + ); + } + }); + + //Err(e) => log::error!( + //target: "Client", + //"{}: {}", hashed_peer, e + //), } }; } diff --git a/src/lib.rs b/src/lib.rs index 0f88885..ce21da3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,6 @@ use crate::error::{ChorusError, Error}; use crate::globals::GLOBALS; use crate::ip::{HashedIp, HashedPeer, IpData, SessionExit}; use crate::reply::NostrReply; -use crate::tls::MaybeTlsStream; use futures::{sink::SinkExt, stream::StreamExt}; use hyper::service::Service; use hyper::upgrade::Upgraded; @@ -35,13 +34,19 @@ use std::sync::atomic::Ordering; use std::task::{Context, Poll}; use std::time::Duration; use textnonce::TextNonce; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; use tokio::time::Instant; +use tokio_rustls::server::TlsStream; use tungstenite::protocol::WebSocketConfig; use tungstenite::Message; +pub trait FullStream: AsyncRead + AsyncWrite + Unpin + Send {} +impl FullStream for TcpStream {} +impl FullStream for TlsStream {} + /// Serve a single network connection -pub async fn serve(stream: MaybeTlsStream, peer: HashedPeer) -> Result<(), Error> { +pub async fn serve(stream: Box, peer: HashedPeer) -> Result<(), Error> { // Serve the network stream with our http server and our HttpService let service = HttpService { peer }; diff --git a/src/tls.rs b/src/tls.rs index 5993f6c..440ab40 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1,14 +1,9 @@ use crate::config::Config; use crate::error::{ChorusError, Error}; -use crate::globals::GLOBALS; use rustls_pki_types::{CertificateDer, PrivateKeyDer}; use std::fs::File; use std::io::BufReader; -use std::pin::Pin; -use std::sync::atomic::Ordering; use std::sync::Arc; -use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_rustls::{rustls, TlsAcceptor}; pub fn tls_acceptor(config: &Config) -> Result { @@ -38,92 +33,3 @@ pub fn tls_acceptor(config: &Config) -> Result { Ok(TlsAcceptor::from(Arc::new(tls_config))) } - -/// A stream that might be protected with TLS. -#[allow(clippy::large_enum_variant)] // not great though -#[derive(Debug)] -pub enum MaybeTlsStream { - /// Unencrypted socket stream. - Plain(S), - /// Encrypted socket stream using `rustls`. - Rustls(tokio_rustls::server::TlsStream), -} - -impl AsyncRead for MaybeTlsStream { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - match self.get_mut() { - MaybeTlsStream::Plain(ref mut s) => { - // Count bytes for statistics - let pre = buf.filled().len(); - let result = Pin::new(s).poll_read(cx, buf); - let post = buf.filled().len(); - let count = post - pre; - if count > 0 { - let _ = GLOBALS - .bytes_inbound - .fetch_add(count as u64, Ordering::SeqCst); - } - result - } - MaybeTlsStream::Rustls(s) => { - // Count bytes for statistics - let pre = buf.filled().len(); - let result = Pin::new(s).poll_read(cx, buf); - let post = buf.filled().len(); - let count = post - pre; - if count > 0 { - let _ = GLOBALS - .bytes_inbound - .fetch_add(count as u64, Ordering::SeqCst); - } - result - } - } - } -} - -impl AsyncWrite for MaybeTlsStream { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match self.get_mut() { - MaybeTlsStream::Plain(ref mut s) => { - // Count bytes for statistics - let _ = GLOBALS - .bytes_outbound - .fetch_add(buf.len() as u64, Ordering::SeqCst); - Pin::new(s).poll_write(cx, buf) - } - MaybeTlsStream::Rustls(s) => { - // Count bytes for statistics - let _ = GLOBALS - .bytes_outbound - .fetch_add(buf.len() as u64, Ordering::SeqCst); - Pin::new(s).poll_write(cx, buf) - } - } - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.get_mut() { - MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_flush(cx), - MaybeTlsStream::Rustls(s) => Pin::new(s).poll_flush(cx), - } - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match self.get_mut() { - MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_shutdown(cx), - MaybeTlsStream::Rustls(s) => Pin::new(s).poll_shutdown(cx), - } - } -}