mirror of
https://github.com/mikedilger/chorus.git
synced 2026-04-03 06:41:16 +00:00
TLS optional
This commit is contained in:
parent
d2191f9699
commit
6cb037986e
111
src/main.rs
111
src/main.rs
@ -12,15 +12,14 @@ use crate::config::Config;
|
||||
use crate::error::Error;
|
||||
use crate::globals::GLOBALS;
|
||||
use crate::store::Store;
|
||||
use crate::tls::MaybeTlsStream;
|
||||
use hyper::{Body, Request, Response};
|
||||
use rustls::{Certificate, PrivateKey};
|
||||
use std::env;
|
||||
use std::error::Error as StdError;
|
||||
use std::fs::{File, OpenOptions};
|
||||
use std::io::{BufReader, Read};
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_rustls::{rustls, TlsAcceptor};
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Read;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Error> {
|
||||
@ -46,32 +45,12 @@ async fn main() -> Result<(), Error> {
|
||||
let _ = GLOBALS.store.set(store);
|
||||
|
||||
// TLS setup
|
||||
let tls_acceptor = {
|
||||
let certs: Vec<Certificate> =
|
||||
rustls_pemfile::certs(&mut BufReader::new(File::open(&config.certchain_pem_path)?))?
|
||||
.drain(..)
|
||||
.map(Certificate)
|
||||
.collect();
|
||||
|
||||
let mut keys: Vec<PrivateKey> = rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(
|
||||
File::open(&config.key_pem_path)?,
|
||||
))?
|
||||
.drain(..)
|
||||
.rev()
|
||||
.map(PrivateKey)
|
||||
.collect();
|
||||
|
||||
let key = match keys.pop() {
|
||||
Some(k) => k,
|
||||
None => return Err(Error::NoPrivateKey),
|
||||
};
|
||||
|
||||
let tls_config = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)?;
|
||||
|
||||
TlsAcceptor::from(Arc::new(tls_config))
|
||||
let maybe_tls_acceptor = if config.use_tls {
|
||||
log::info!("Using TLS");
|
||||
Some(tls::tls_acceptor(&config)?)
|
||||
} else {
|
||||
log::info!("Not using TLS");
|
||||
None
|
||||
};
|
||||
|
||||
// Bind listener to port
|
||||
@ -85,39 +64,51 @@ async fn main() -> Result<(), Error> {
|
||||
loop {
|
||||
let (tcp_stream, peer_addr) = listener.accept().await?;
|
||||
|
||||
let acceptor = tls_acceptor.clone();
|
||||
tokio::spawn(async move {
|
||||
match acceptor.accept(tcp_stream).await {
|
||||
Err(e) => log::error!("{}", e),
|
||||
Ok(tls_stream) => {
|
||||
let connection = GLOBALS
|
||||
.http_server
|
||||
.serve_connection(tls_stream, hyper::service::service_fn(handle_request));
|
||||
tokio::spawn(async move {
|
||||
// If our service exits with an error, log the error
|
||||
if let Err(he) = connection.await {
|
||||
if let Some(src) = he.source() {
|
||||
if &*format!("{}", src)
|
||||
== "Transport endpoint is not connected (os error 107)"
|
||||
{
|
||||
// do nothing
|
||||
} else {
|
||||
// Print in detail
|
||||
log::info!("{:?}", src);
|
||||
}
|
||||
} else {
|
||||
// Print in less detail
|
||||
let e: Error = he.into();
|
||||
log::info!("{}", e);
|
||||
}
|
||||
if let Some(tls_acceptor) = &maybe_tls_acceptor {
|
||||
let tls_acceptor_clone = tls_acceptor.clone();
|
||||
tokio::spawn(async move {
|
||||
match tls_acceptor_clone.accept(tcp_stream).await {
|
||||
Err(e) => log::error!("{}", e),
|
||||
Ok(tls_stream) => {
|
||||
if let Err(e) = serve(MaybeTlsStream::Rustls(tls_stream), peer_addr).await {
|
||||
log::error!("{}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
} else {
|
||||
serve(MaybeTlsStream::Plain(tcp_stream), peer_addr).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Serve a single network connection
|
||||
async fn serve(stream: MaybeTlsStream<TcpStream>, peer_addr: SocketAddr) -> Result<(), Error> {
|
||||
let connection = GLOBALS
|
||||
.http_server
|
||||
.serve_connection(stream, hyper::service::service_fn(handle_request));
|
||||
|
||||
tokio::spawn(async move {
|
||||
// If our service exits with an error, log the error
|
||||
if let Err(he) = connection.await {
|
||||
if let Some(src) = he.source() {
|
||||
if &*format!("{}", src) == "Transport endpoint is not connected (os error 107)" {
|
||||
// do nothing
|
||||
} else {
|
||||
// Print in detail
|
||||
log::error!("{:?}", src);
|
||||
}
|
||||
} else {
|
||||
// Print in less detail
|
||||
let e: Error = he.into();
|
||||
log::error!("{}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_request(_request: Request<Body>) -> Result<Response<Body>, Error> {
|
||||
web::serve_http().await
|
||||
}
|
||||
|
||||
42
src/tls.rs
42
src/tls.rs
@ -1,8 +1,40 @@
|
||||
use std::{
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use crate::config::Config;
|
||||
use crate::error::Error;
|
||||
use rustls::{Certificate, PrivateKey};
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_rustls::{rustls, TlsAcceptor};
|
||||
|
||||
pub fn tls_acceptor(config: &Config) -> Result<TlsAcceptor, Error> {
|
||||
let certs: Vec<Certificate> =
|
||||
rustls_pemfile::certs(&mut BufReader::new(File::open(&config.certchain_pem_path)?))?
|
||||
.drain(..)
|
||||
.map(Certificate)
|
||||
.collect();
|
||||
|
||||
let mut keys: Vec<PrivateKey> =
|
||||
rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(File::open(&config.key_pem_path)?))?
|
||||
.drain(..)
|
||||
.rev()
|
||||
.map(PrivateKey)
|
||||
.collect();
|
||||
|
||||
let key = match keys.pop() {
|
||||
Some(k) => k,
|
||||
None => return Err(Error::NoPrivateKey),
|
||||
};
|
||||
|
||||
let tls_config = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)?;
|
||||
|
||||
Ok(TlsAcceptor::from(Arc::new(tls_config)))
|
||||
}
|
||||
|
||||
/// A stream that might be protected with TLS.
|
||||
#[allow(clippy::large_enum_variant)] // not great though
|
||||
@ -11,7 +43,7 @@ pub enum MaybeTlsStream<S> {
|
||||
/// Unencrypted socket stream.
|
||||
Plain(S),
|
||||
/// Encrypted socket stream using `rustls`.
|
||||
Rustls(tokio_rustls::client::TlsStream<S>),
|
||||
Rustls(tokio_rustls::server::TlsStream<S>),
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<S> {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user