diff --git a/src/main.rs b/src/main.rs index e3cc93a..90b12b4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,7 @@ pub mod config; pub mod error; pub mod globals; pub mod store; +pub mod tls; pub mod types; pub mod web; diff --git a/src/tls.rs b/src/tls.rs new file mode 100644 index 0000000..fbdf1f8 --- /dev/null +++ b/src/tls.rs @@ -0,0 +1,58 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +/// A stream that might be protected with TLS. +#[allow(clippy::large_enum_variant)] // not great though +#[derive(Debug)] +pub enum MaybeTlsStream { + /// Unencrypted socket stream. + Plain(S), + /// Encrypted socket stream using `rustls`. + Rustls(tokio_rustls::client::TlsStream), +} + +impl AsyncRead for MaybeTlsStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_read(cx, buf), + MaybeTlsStream::Rustls(s) => Pin::new(s).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for MaybeTlsStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_write(cx, buf), + MaybeTlsStream::Rustls(s) => Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_flush(cx), + MaybeTlsStream::Rustls(s) => Pin::new(s).poll_flush(cx), + } + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.get_mut() { + MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_shutdown(cx), + MaybeTlsStream::Rustls(s) => Pin::new(s).poll_shutdown(cx), + } + } +}