diff --git a/src/main.rs b/src/main.rs index d02dd28..37ba97c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,12 +13,16 @@ use crate::error::Error; use crate::globals::GLOBALS; use crate::store::Store; use crate::tls::MaybeTlsStream; +use hyper::service::Service; use hyper::{Body, Request, Response}; use std::env; use std::error::Error as StdError; use std::fs::OpenOptions; +use std::future::Future; use std::io::Read; use std::net::SocketAddr; +use std::pin::Pin; +use std::task::{Context, Poll}; use tokio::net::{TcpListener, TcpStream}; #[tokio::main] @@ -84,9 +88,10 @@ async fn main() -> Result<(), Error> { // Serve a single network connection async fn serve(stream: MaybeTlsStream, peer_addr: SocketAddr) -> Result<(), Error> { - let connection = GLOBALS - .http_server - .serve_connection(stream, hyper::service::service_fn(handle_request)); + // Serve the network stream with our http server and our HttpService + let service = HttpService { peer: peer_addr }; + + let connection = GLOBALS.http_server.serve_connection(stream, service); tokio::spawn(async move { // If our service exits with an error, log the error @@ -109,7 +114,32 @@ async fn serve(stream: MaybeTlsStream, peer_addr: SocketAddr) -> Resu Ok(()) } -async fn handle_request(request: Request) -> Result, Error> { +// This is our per-connection HTTP service +struct HttpService { + peer: SocketAddr, +} + +impl Service> for HttpService { + type Response = Response; + type Error = Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + // This is called for each HTTP request made by the client + // NOTE: it is not called for each websocket message once upgraded. + fn call(&mut self, req: Request) -> Self::Future { + let peer = self.peer; + Box::pin(async move { handle_http_request(peer, req).await }) + } +} + +async fn handle_http_request( + _peer: SocketAddr, + request: Request, +) -> Result, Error> { // check for Accept header of application/nostr+json if let Some(accept) = request.headers().get("Accept") { if let Ok(s) = accept.to_str() {