mirror of
https://github.com/mikedilger/chorus.git
synced 2026-06-02 07:02:09 +00:00
retire MaybeTlsStream, use a dyn trait instead
This commit is contained in:
parent
97a8a16999
commit
f00f6cccf4
@ -2,7 +2,7 @@ use chorus::config::{Config, FriendlyConfig};
|
||||
use chorus::error::Error;
|
||||
use chorus::globals::GLOBALS;
|
||||
use chorus::ip::HashedPeer;
|
||||
use chorus::tls::MaybeTlsStream;
|
||||
use chorus::FullStream;
|
||||
use std::env;
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Read;
|
||||
@ -104,27 +104,35 @@ async fn main() -> Result<(), Error> {
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(tls_acceptor) = &maybe_tls_acceptor {
|
||||
let tls_acceptor_clone = tls_acceptor.clone();
|
||||
tokio::spawn(async move {
|
||||
match tls_acceptor_clone.accept(tcp_stream).await {
|
||||
Err(e) => log::error!(
|
||||
target: "Client",
|
||||
"{}: {}", hashed_peer, e
|
||||
),
|
||||
Ok(tls_stream) => {
|
||||
if let Err(e) = chorus::serve(MaybeTlsStream::Rustls(tls_stream), hashed_peer).await {
|
||||
let maybe_tls_acceptor_clone = maybe_tls_acceptor.clone();
|
||||
tokio::spawn(async move {
|
||||
let stream: Box<dyn FullStream> = match maybe_tls_acceptor_clone {
|
||||
Some(tls_acceptor) => {
|
||||
match tls_acceptor.accept(tcp_stream).await {
|
||||
Ok(stream) => Box::new(stream),
|
||||
Err(e) => {
|
||||
log::error!(
|
||||
target: "Client",
|
||||
"{}: {}", hashed_peer, e
|
||||
"{}: TLS accept: {}", hashed_peer, e
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
} else {
|
||||
chorus::serve(MaybeTlsStream::Plain(tcp_stream), hashed_peer).await?;
|
||||
}
|
||||
},
|
||||
None => Box::new(tcp_stream)
|
||||
};
|
||||
if let Err(e) = chorus::serve(stream, hashed_peer).await {
|
||||
log::error!(
|
||||
target: "Client",
|
||||
"{}: {}", hashed_peer, e
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
//Err(e) => log::error!(
|
||||
//target: "Client",
|
||||
//"{}: {}", hashed_peer, e
|
||||
//),
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@ -12,7 +12,6 @@ use crate::error::{ChorusError, Error};
|
||||
use crate::globals::GLOBALS;
|
||||
use crate::ip::{HashedIp, HashedPeer, IpData, SessionExit};
|
||||
use crate::reply::NostrReply;
|
||||
use crate::tls::MaybeTlsStream;
|
||||
use futures::{sink::SinkExt, stream::StreamExt};
|
||||
use hyper::service::Service;
|
||||
use hyper::upgrade::Upgraded;
|
||||
@ -35,13 +34,19 @@ use std::sync::atomic::Ordering;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use textnonce::TextNonce;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time::Instant;
|
||||
use tokio_rustls::server::TlsStream;
|
||||
use tungstenite::protocol::WebSocketConfig;
|
||||
use tungstenite::Message;
|
||||
|
||||
pub trait FullStream: AsyncRead + AsyncWrite + Unpin + Send {}
|
||||
impl FullStream for TcpStream {}
|
||||
impl FullStream for TlsStream<TcpStream> {}
|
||||
|
||||
/// Serve a single network connection
|
||||
pub async fn serve(stream: MaybeTlsStream<TcpStream>, peer: HashedPeer) -> Result<(), Error> {
|
||||
pub async fn serve(stream: Box<dyn FullStream>, peer: HashedPeer) -> Result<(), Error> {
|
||||
// Serve the network stream with our http server and our HttpService
|
||||
let service = HttpService { peer };
|
||||
|
||||
|
||||
94
src/tls.rs
94
src/tls.rs
@ -1,14 +1,9 @@
|
||||
use crate::config::Config;
|
||||
use crate::error::{ChorusError, Error};
|
||||
use crate::globals::GLOBALS;
|
||||
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_rustls::{rustls, TlsAcceptor};
|
||||
|
||||
pub fn tls_acceptor(config: &Config) -> Result<TlsAcceptor, Error> {
|
||||
@ -38,92 +33,3 @@ pub fn tls_acceptor(config: &Config) -> Result<TlsAcceptor, Error> {
|
||||
|
||||
Ok(TlsAcceptor::from(Arc::new(tls_config)))
|
||||
}
|
||||
|
||||
/// A stream that might be protected with TLS.
|
||||
#[allow(clippy::large_enum_variant)] // not great though
|
||||
#[derive(Debug)]
|
||||
pub enum MaybeTlsStream<S> {
|
||||
/// Unencrypted socket stream.
|
||||
Plain(S),
|
||||
/// Encrypted socket stream using `rustls`.
|
||||
Rustls(tokio_rustls::server::TlsStream<S>),
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<S> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
MaybeTlsStream::Plain(ref mut s) => {
|
||||
// Count bytes for statistics
|
||||
let pre = buf.filled().len();
|
||||
let result = Pin::new(s).poll_read(cx, buf);
|
||||
let post = buf.filled().len();
|
||||
let count = post - pre;
|
||||
if count > 0 {
|
||||
let _ = GLOBALS
|
||||
.bytes_inbound
|
||||
.fetch_add(count as u64, Ordering::SeqCst);
|
||||
}
|
||||
result
|
||||
}
|
||||
MaybeTlsStream::Rustls(s) => {
|
||||
// Count bytes for statistics
|
||||
let pre = buf.filled().len();
|
||||
let result = Pin::new(s).poll_read(cx, buf);
|
||||
let post = buf.filled().len();
|
||||
let count = post - pre;
|
||||
if count > 0 {
|
||||
let _ = GLOBALS
|
||||
.bytes_inbound
|
||||
.fetch_add(count as u64, Ordering::SeqCst);
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<S> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
match self.get_mut() {
|
||||
MaybeTlsStream::Plain(ref mut s) => {
|
||||
// Count bytes for statistics
|
||||
let _ = GLOBALS
|
||||
.bytes_outbound
|
||||
.fetch_add(buf.len() as u64, Ordering::SeqCst);
|
||||
Pin::new(s).poll_write(cx, buf)
|
||||
}
|
||||
MaybeTlsStream::Rustls(s) => {
|
||||
// Count bytes for statistics
|
||||
let _ = GLOBALS
|
||||
.bytes_outbound
|
||||
.fetch_add(buf.len() as u64, Ordering::SeqCst);
|
||||
Pin::new(s).poll_write(cx, buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
match self.get_mut() {
|
||||
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_flush(cx),
|
||||
MaybeTlsStream::Rustls(s) => Pin::new(s).poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
match self.get_mut() {
|
||||
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_shutdown(cx),
|
||||
MaybeTlsStream::Rustls(s) => Pin::new(s).poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user