retire MaybeTlsStream, use a dyn trait instead

This commit is contained in:
Mike Dilger 2024-06-22 12:04:23 +12:00
parent 97a8a16999
commit f00f6cccf4
3 changed files with 32 additions and 113 deletions

View File

@ -2,7 +2,7 @@ use chorus::config::{Config, FriendlyConfig};
use chorus::error::Error;
use chorus::globals::GLOBALS;
use chorus::ip::HashedPeer;
use chorus::tls::MaybeTlsStream;
use chorus::FullStream;
use std::env;
use std::fs::OpenOptions;
use std::io::Read;
@ -104,27 +104,35 @@ async fn main() -> Result<(), Error> {
}
}
if let Some(tls_acceptor) = &maybe_tls_acceptor {
let tls_acceptor_clone = tls_acceptor.clone();
tokio::spawn(async move {
match tls_acceptor_clone.accept(tcp_stream).await {
Err(e) => log::error!(
target: "Client",
"{}: {}", hashed_peer, e
),
Ok(tls_stream) => {
if let Err(e) = chorus::serve(MaybeTlsStream::Rustls(tls_stream), hashed_peer).await {
let maybe_tls_acceptor_clone = maybe_tls_acceptor.clone();
tokio::spawn(async move {
let stream: Box<dyn FullStream> = match maybe_tls_acceptor_clone {
Some(tls_acceptor) => {
match tls_acceptor.accept(tcp_stream).await {
Ok(stream) => Box::new(stream),
Err(e) => {
log::error!(
target: "Client",
"{}: {}", hashed_peer, e
"{}: TLS accept: {}", hashed_peer, e
);
return;
}
}
}
});
} else {
chorus::serve(MaybeTlsStream::Plain(tcp_stream), hashed_peer).await?;
}
},
None => Box::new(tcp_stream)
};
if let Err(e) = chorus::serve(stream, hashed_peer).await {
log::error!(
target: "Client",
"{}: {}", hashed_peer, e
);
}
});
//Err(e) => log::error!(
//target: "Client",
//"{}: {}", hashed_peer, e
//),
}
};
}

View File

@ -12,7 +12,6 @@ use crate::error::{ChorusError, Error};
use crate::globals::GLOBALS;
use crate::ip::{HashedIp, HashedPeer, IpData, SessionExit};
use crate::reply::NostrReply;
use crate::tls::MaybeTlsStream;
use futures::{sink::SinkExt, stream::StreamExt};
use hyper::service::Service;
use hyper::upgrade::Upgraded;
@ -35,13 +34,19 @@ use std::sync::atomic::Ordering;
use std::task::{Context, Poll};
use std::time::Duration;
use textnonce::TextNonce;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::time::Instant;
use tokio_rustls::server::TlsStream;
use tungstenite::protocol::WebSocketConfig;
use tungstenite::Message;
pub trait FullStream: AsyncRead + AsyncWrite + Unpin + Send {}
impl FullStream for TcpStream {}
impl FullStream for TlsStream<TcpStream> {}
/// Serve a single network connection
pub async fn serve(stream: MaybeTlsStream<TcpStream>, peer: HashedPeer) -> Result<(), Error> {
pub async fn serve(stream: Box<dyn FullStream>, peer: HashedPeer) -> Result<(), Error> {
// Serve the network stream with our http server and our HttpService
let service = HttpService { peer };

View File

@ -1,14 +1,9 @@
use crate::config::Config;
use crate::error::{ChorusError, Error};
use crate::globals::GLOBALS;
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
use std::fs::File;
use std::io::BufReader;
use std::pin::Pin;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::{rustls, TlsAcceptor};
pub fn tls_acceptor(config: &Config) -> Result<TlsAcceptor, Error> {
@ -38,92 +33,3 @@ pub fn tls_acceptor(config: &Config) -> Result<TlsAcceptor, Error> {
Ok(TlsAcceptor::from(Arc::new(tls_config)))
}
/// A stream that might be protected with TLS.
#[allow(clippy::large_enum_variant)] // not great though
#[derive(Debug)]
pub enum MaybeTlsStream<S> {
/// Unencrypted socket stream.
Plain(S),
/// Encrypted socket stream using `rustls`.
Rustls(tokio_rustls::server::TlsStream<S>),
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<S> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => {
// Count bytes for statistics
let pre = buf.filled().len();
let result = Pin::new(s).poll_read(cx, buf);
let post = buf.filled().len();
let count = post - pre;
if count > 0 {
let _ = GLOBALS
.bytes_inbound
.fetch_add(count as u64, Ordering::SeqCst);
}
result
}
MaybeTlsStream::Rustls(s) => {
// Count bytes for statistics
let pre = buf.filled().len();
let result = Pin::new(s).poll_read(cx, buf);
let post = buf.filled().len();
let count = post - pre;
if count > 0 {
let _ = GLOBALS
.bytes_inbound
.fetch_add(count as u64, Ordering::SeqCst);
}
result
}
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<S> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => {
// Count bytes for statistics
let _ = GLOBALS
.bytes_outbound
.fetch_add(buf.len() as u64, Ordering::SeqCst);
Pin::new(s).poll_write(cx, buf)
}
MaybeTlsStream::Rustls(s) => {
// Count bytes for statistics
let _ = GLOBALS
.bytes_outbound
.fetch_add(buf.len() as u64, Ordering::SeqCst);
Pin::new(s).poll_write(cx, buf)
}
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_flush(cx),
MaybeTlsStream::Rustls(s) => Pin::new(s).poll_flush(cx),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_shutdown(cx),
MaybeTlsStream::Rustls(s) => Pin::new(s).poll_shutdown(cx),
}
}
}