TLS optional

This commit is contained in:
Mike Dilger 2023-10-26 20:16:33 +13:00
parent d2191f9699
commit 6cb037986e
2 changed files with 88 additions and 65 deletions

View File

@ -12,15 +12,14 @@ use crate::config::Config;
use crate::error::Error;
use crate::globals::GLOBALS;
use crate::store::Store;
use crate::tls::MaybeTlsStream;
use hyper::{Body, Request, Response};
use rustls::{Certificate, PrivateKey};
use std::env;
use std::error::Error as StdError;
use std::fs::{File, OpenOptions};
use std::io::{BufReader, Read};
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio_rustls::{rustls, TlsAcceptor};
use std::fs::OpenOptions;
use std::io::Read;
use std::net::SocketAddr;
use tokio::net::{TcpListener, TcpStream};
#[tokio::main]
async fn main() -> Result<(), Error> {
@ -46,32 +45,12 @@ async fn main() -> Result<(), Error> {
let _ = GLOBALS.store.set(store);
// TLS setup
let tls_acceptor = {
let certs: Vec<Certificate> =
rustls_pemfile::certs(&mut BufReader::new(File::open(&config.certchain_pem_path)?))?
.drain(..)
.map(Certificate)
.collect();
let mut keys: Vec<PrivateKey> = rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(
File::open(&config.key_pem_path)?,
))?
.drain(..)
.rev()
.map(PrivateKey)
.collect();
let key = match keys.pop() {
Some(k) => k,
None => return Err(Error::NoPrivateKey),
};
let tls_config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)?;
TlsAcceptor::from(Arc::new(tls_config))
let maybe_tls_acceptor = if config.use_tls {
log::info!("Using TLS");
Some(tls::tls_acceptor(&config)?)
} else {
log::info!("Not using TLS");
None
};
// Bind listener to port
@ -85,39 +64,51 @@ async fn main() -> Result<(), Error> {
loop {
let (tcp_stream, peer_addr) = listener.accept().await?;
let acceptor = tls_acceptor.clone();
tokio::spawn(async move {
match acceptor.accept(tcp_stream).await {
Err(e) => log::error!("{}", e),
Ok(tls_stream) => {
let connection = GLOBALS
.http_server
.serve_connection(tls_stream, hyper::service::service_fn(handle_request));
tokio::spawn(async move {
// If our service exits with an error, log the error
if let Err(he) = connection.await {
if let Some(src) = he.source() {
if &*format!("{}", src)
== "Transport endpoint is not connected (os error 107)"
{
// do nothing
} else {
// Print in detail
log::info!("{:?}", src);
}
} else {
// Print in less detail
let e: Error = he.into();
log::info!("{}", e);
}
if let Some(tls_acceptor) = &maybe_tls_acceptor {
let tls_acceptor_clone = tls_acceptor.clone();
tokio::spawn(async move {
match tls_acceptor_clone.accept(tcp_stream).await {
Err(e) => log::error!("{}", e),
Ok(tls_stream) => {
if let Err(e) = serve(MaybeTlsStream::Rustls(tls_stream), peer_addr).await {
log::error!("{}", e);
}
});
}
}
}
});
});
} else {
serve(MaybeTlsStream::Plain(tcp_stream), peer_addr).await?;
}
}
}
// Serve a single network connection
async fn serve(stream: MaybeTlsStream<TcpStream>, peer_addr: SocketAddr) -> Result<(), Error> {
let connection = GLOBALS
.http_server
.serve_connection(stream, hyper::service::service_fn(handle_request));
tokio::spawn(async move {
// If our service exits with an error, log the error
if let Err(he) = connection.await {
if let Some(src) = he.source() {
if &*format!("{}", src) == "Transport endpoint is not connected (os error 107)" {
// do nothing
} else {
// Print in detail
log::error!("{:?}", src);
}
} else {
// Print in less detail
let e: Error = he.into();
log::error!("{}", e);
}
}
});
Ok(())
}
async fn handle_request(_request: Request<Body>) -> Result<Response<Body>, Error> {
web::serve_http().await
}

View File

@ -1,8 +1,40 @@
use std::{
pin::Pin,
task::{Context, Poll},
};
use crate::config::Config;
use crate::error::Error;
use rustls::{Certificate, PrivateKey};
use std::fs::File;
use std::io::BufReader;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::{rustls, TlsAcceptor};
pub fn tls_acceptor(config: &Config) -> Result<TlsAcceptor, Error> {
let certs: Vec<Certificate> =
rustls_pemfile::certs(&mut BufReader::new(File::open(&config.certchain_pem_path)?))?
.drain(..)
.map(Certificate)
.collect();
let mut keys: Vec<PrivateKey> =
rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(File::open(&config.key_pem_path)?))?
.drain(..)
.rev()
.map(PrivateKey)
.collect();
let key = match keys.pop() {
Some(k) => k,
None => return Err(Error::NoPrivateKey),
};
let tls_config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)?;
Ok(TlsAcceptor::from(Arc::new(tls_config)))
}
/// A stream that might be protected with TLS.
#[allow(clippy::large_enum_variant)] // not great though
@ -11,7 +43,7 @@ pub enum MaybeTlsStream<S> {
/// Unencrypted socket stream.
Plain(S),
/// Encrypted socket stream using `rustls`.
Rustls(tokio_rustls::client::TlsStream<S>),
Rustls(tokio_rustls::server::TlsStream<S>),
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<S> {