diff --git a/src/main.rs b/src/main.rs index c464b3b..5723ca2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,15 +12,14 @@ use crate::config::Config; use crate::error::Error; use crate::globals::GLOBALS; use crate::store::Store; +use crate::tls::MaybeTlsStream; use hyper::{Body, Request, Response}; -use rustls::{Certificate, PrivateKey}; use std::env; use std::error::Error as StdError; -use std::fs::{File, OpenOptions}; -use std::io::{BufReader, Read}; -use std::sync::Arc; -use tokio::net::TcpListener; -use tokio_rustls::{rustls, TlsAcceptor}; +use std::fs::OpenOptions; +use std::io::Read; +use std::net::SocketAddr; +use tokio::net::{TcpListener, TcpStream}; #[tokio::main] async fn main() -> Result<(), Error> { @@ -46,32 +45,12 @@ async fn main() -> Result<(), Error> { let _ = GLOBALS.store.set(store); // TLS setup - let tls_acceptor = { - let certs: Vec = - rustls_pemfile::certs(&mut BufReader::new(File::open(&config.certchain_pem_path)?))? - .drain(..) - .map(Certificate) - .collect(); - - let mut keys: Vec = rustls_pemfile::pkcs8_private_keys(&mut BufReader::new( - File::open(&config.key_pem_path)?, - ))? - .drain(..) - .rev() - .map(PrivateKey) - .collect(); - - let key = match keys.pop() { - Some(k) => k, - None => return Err(Error::NoPrivateKey), - }; - - let tls_config = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(certs, key)?; - - TlsAcceptor::from(Arc::new(tls_config)) + let maybe_tls_acceptor = if config.use_tls { + log::info!("Using TLS"); + Some(tls::tls_acceptor(&config)?) + } else { + log::info!("Not using TLS"); + None }; // Bind listener to port @@ -85,39 +64,51 @@ async fn main() -> Result<(), Error> { loop { let (tcp_stream, peer_addr) = listener.accept().await?; - let acceptor = tls_acceptor.clone(); - tokio::spawn(async move { - match acceptor.accept(tcp_stream).await { - Err(e) => log::error!("{}", e), - Ok(tls_stream) => { - let connection = GLOBALS - .http_server - .serve_connection(tls_stream, hyper::service::service_fn(handle_request)); - tokio::spawn(async move { - // If our service exits with an error, log the error - if let Err(he) = connection.await { - if let Some(src) = he.source() { - if &*format!("{}", src) - == "Transport endpoint is not connected (os error 107)" - { - // do nothing - } else { - // Print in detail - log::info!("{:?}", src); - } - } else { - // Print in less detail - let e: Error = he.into(); - log::info!("{}", e); - } + if let Some(tls_acceptor) = &maybe_tls_acceptor { + let tls_acceptor_clone = tls_acceptor.clone(); + tokio::spawn(async move { + match tls_acceptor_clone.accept(tcp_stream).await { + Err(e) => log::error!("{}", e), + Ok(tls_stream) => { + if let Err(e) = serve(MaybeTlsStream::Rustls(tls_stream), peer_addr).await { + log::error!("{}", e); } - }); + } } - } - }); + }); + } else { + serve(MaybeTlsStream::Plain(tcp_stream), peer_addr).await?; + } } } +// Serve a single network connection +async fn serve(stream: MaybeTlsStream, peer_addr: SocketAddr) -> Result<(), Error> { + let connection = GLOBALS + .http_server + .serve_connection(stream, hyper::service::service_fn(handle_request)); + + tokio::spawn(async move { + // If our service exits with an error, log the error + if let Err(he) = connection.await { + if let Some(src) = he.source() { + if &*format!("{}", src) == "Transport endpoint is not connected (os error 107)" { + // do nothing + } else { + // Print in detail + log::error!("{:?}", src); + } + } else { + // Print in less detail + let e: Error = he.into(); + log::error!("{}", e); + } + } + }); + + Ok(()) +} + async fn handle_request(_request: Request) -> Result, Error> { web::serve_http().await } diff --git a/src/tls.rs b/src/tls.rs index fbdf1f8..1f04084 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1,8 +1,40 @@ -use std::{ - pin::Pin, - task::{Context, Poll}, -}; +use crate::config::Config; +use crate::error::Error; +use rustls::{Certificate, PrivateKey}; +use std::fs::File; +use std::io::BufReader; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_rustls::{rustls, TlsAcceptor}; + +pub fn tls_acceptor(config: &Config) -> Result { + let certs: Vec = + rustls_pemfile::certs(&mut BufReader::new(File::open(&config.certchain_pem_path)?))? + .drain(..) + .map(Certificate) + .collect(); + + let mut keys: Vec = + rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(File::open(&config.key_pem_path)?))? + .drain(..) + .rev() + .map(PrivateKey) + .collect(); + + let key = match keys.pop() { + Some(k) => k, + None => return Err(Error::NoPrivateKey), + }; + + let tls_config = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certs, key)?; + + Ok(TlsAcceptor::from(Arc::new(tls_config))) +} /// A stream that might be protected with TLS. #[allow(clippy::large_enum_variant)] // not great though @@ -11,7 +43,7 @@ pub enum MaybeTlsStream { /// Unencrypted socket stream. Plain(S), /// Encrypted socket stream using `rustls`. - Rustls(tokio_rustls::client::TlsStream), + Rustls(tokio_rustls::server::TlsStream), } impl AsyncRead for MaybeTlsStream {