signal-cli/src/main/java/org/asamk/signal/http/HttpServerHandler.java
2026-06-11 11:43:49 +02:00

369 lines
14 KiB
Java

package org.asamk.signal.http;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpServer;
import org.asamk.signal.commands.Commands;
import org.asamk.signal.json.JsonReceiveMessageHandler;
import org.asamk.signal.jsonrpc.JsonRpcReader;
import org.asamk.signal.jsonrpc.JsonRpcResponse;
import org.asamk.signal.jsonrpc.JsonRpcSender;
import org.asamk.signal.jsonrpc.SignalJsonRpcCommandHandler;
import org.asamk.signal.manager.Manager;
import org.asamk.signal.manager.MultiAccountManager;
import org.asamk.signal.manager.api.Pair;
import org.asamk.signal.util.Util;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
/**
 * Exposes signal-cli over HTTP using the JDK's built-in {@link HttpServer}.
 *
 * <p>Endpoints:
 * <ul>
 *   <li>{@code POST /api/v1/rpc} - one JSON-RPC request per HTTP request</li>
 *   <li>{@code GET /api/v1/events} - a server-sent-events stream of received messages</li>
 *   <li>{@code GET /api/v1/check} - an empty 200 response</li>
 * </ul>
 *
 * <p>The server has no authentication. Unless it is bound to a wildcard address, the rpc and
 * events endpoints reject requests whose {@code Host} header does not name this server
 * (NOTE(review): presumably as a DNS-rebinding mitigation).
 */
public class HttpServerHandler implements AutoCloseable {

    private static final Logger logger = LoggerFactory.getLogger(HttpServerHandler.class);

    /** Maximum idle time on an event stream before a keep-alive is written, in milliseconds. */
    private static final long KEEP_ALIVE_INTERVAL_MS = 15_000;

    private final ObjectMapper objectMapper = Util.createJsonObjectMapper();
    private final InetSocketAddress address;
    private final SignalJsonRpcCommandHandler commandHandler;
    // Depending on the constructor used, one of c / m is set and the other is null.
    private final MultiAccountManager c;
    private final Manager m;
    private HttpServer server;
    // True while close() is running; tells open event streams to finish. Waiting event-stream
    // threads block on this object's monitor and are woken via notifyAll().
    private final AtomicBoolean shutdown = new AtomicBoolean(false);
    // Lower-cased host names / IP literals accepted in the Host header (port checked separately).
    private final Set<String> allowedHosts;

    public HttpServerHandler(final InetSocketAddress address, final Manager m) {
        this.address = address;
        commandHandler = new SignalJsonRpcCommandHandler(m, Commands::getCommand);
        this.c = null;
        this.m = m;
        this.allowedHosts = buildAllowedHosts(address);
    }

    public HttpServerHandler(final InetSocketAddress address, final MultiAccountManager c) {
        this.address = address;
        commandHandler = new SignalJsonRpcCommandHandler(c, Commands::getCommand);
        this.c = c;
        this.m = null;
        this.allowedHosts = buildAllowedHosts(address);
    }

    /**
     * Creates, configures and starts the HTTP server. Each exchange runs on its own virtual thread.
     *
     * @throws IOException    if the server socket cannot be created/bound
     * @throws AssertionError if called twice without an intervening {@link #close()}
     */
    public void init() throws IOException {
        if (server != null) {
            throw new AssertionError("HttpServerHandler already initialized");
        }
        logger.debug("Starting HTTP server on {}", address);
        server = HttpServer.create(address, 0);
        server.setExecutor(Executors.newVirtualThreadPerTaskExecutor());
        server.createContext("/api/v1/rpc", this::handleRpcEndpoint);
        server.createContext("/api/v1/events", this::handleEventsEndpoint);
        server.createContext("/api/v1/check", this::handleCheckEndpoint);
        server.start();
        logger.info("Started HTTP server on {}", address);
        // Use the same wildcard test as isHostAllowed so the log reflects what is actually enforced.
        if (isWildcardAddress(address)) {
            logger.warn("HTTP server has no authentication; Host header validation DISABLED because listening on {}", address);
        } else {
            logger.warn("HTTP server has no authentication; Host header is pinned to {}", allowedHosts);
        }
    }

    /**
     * Stops the server. Open event streams are signalled to finish before the server is stopped.
     * After this returns, {@link #init()} may be called again.
     */
    @Override
    public void close() {
        if (server != null) {
            shutdown.set(true);
            synchronized (this) {
                this.notifyAll();
            }
            // Increase this delay when https://bugs.openjdk.org/browse/JDK-8304065 is fixed
            server.stop(2);
            server = null;
            shutdown.set(false);
        }
    }

    /**
     * Sends a complete response and closes the exchange's response body.
     *
     * @param status       HTTP status code
     * @param response     object serialized as the JSON body, or {@code null} for an empty body
     * @param httpExchange the exchange to answer
     */
    private void sendResponse(int status, Object response, HttpExchange httpExchange) throws IOException {
        if (response == null) {
            httpExchange.sendResponseHeaders(status, -1);
            httpExchange.getResponseBody().close();
            return;
        }
        final var byteResponse = objectMapper.writeValueAsBytes(response);
        httpExchange.getResponseHeaders().add("Content-Type", "application/json");
        httpExchange.sendResponseHeaders(status, byteResponse.length);
        // try-with-resources so the body is closed even if the write fails.
        try (var body = httpExchange.getResponseBody()) {
            body.write(byteResponse);
        }
    }

    /**
     * Handles {@code POST /api/v1/rpc}: reads JSON-RPC message(s) from the request body and answers
     * with the single JSON-RPC response (200), or 201 with an empty body if no response was produced.
     */
    private void handleRpcEndpoint(HttpExchange httpExchange) throws IOException {
        if (!isHostAllowed(httpExchange)) {
            logger.warn("Rejected RPC request with invalid Host header: {} from {}",
                    httpExchange.getRequestHeaders().getFirst("Host"), httpExchange.getRemoteAddress());
            sendResponse(421, null, httpExchange);
            return;
        }
        if (!"/api/v1/rpc".equals(httpExchange.getRequestURI().getPath())) {
            sendResponse(404, null, httpExchange);
            return;
        }
        if (!"POST".equals(httpExchange.getRequestMethod())) {
            sendResponse(405, null, httpExchange);
            return;
        }
        if (!isJsonContentType(httpExchange.getRequestHeaders().getFirst("Content-Type"))) {
            sendResponse(415, null, httpExchange);
            return;
        }
        try {
            // Single-slot holder written by the sender callback; more than one response is a bug.
            final Object[] result = {null};
            final var jsonRpcSender = new JsonRpcSender(s -> {
                if (result[0] != null) {
                    throw new AssertionError("There should only be a single JSON-RPC response");
                }
                result[0] = s;
            });
            final var jsonRpcReader = new JsonRpcReader(jsonRpcSender, httpExchange.getRequestBody());
            jsonRpcReader.readMessages((method, params) -> commandHandler.handleRequest(objectMapper, method, params),
                    response -> logger.debug("Received unexpected response for id {}", response.getId()));
            if (result[0] != null) {
                sendResponse(200, result[0], httpExchange);
            } else {
                sendResponse(201, null, httpExchange);
            }
        } catch (Throwable aEx) {
            logger.error("Failed to process request.", aEx);
            sendResponse(200,
                    JsonRpcResponse.forError(new JsonRpcResponse.Error(JsonRpcResponse.Error.INTERNAL_ERROR,
                            "An internal server error has occurred.",
                            null), null),
                    httpExchange);
        }
    }

    /**
     * Handles {@code GET /api/v1/events}: streams received messages as server-sent events until the
     * client disconnects or the server shuts down. The optional {@code account} query parameter
     * selects a single account in multi-account mode.
     */
    private void handleEventsEndpoint(HttpExchange httpExchange) throws IOException {
        if (!isHostAllowed(httpExchange)) {
            logger.warn("Rejected Events request with invalid Host header: {} from {}",
                    httpExchange.getRequestHeaders().getFirst("Host"), httpExchange.getRemoteAddress());
            sendResponse(421, null, httpExchange);
            return;
        }
        if (!"/api/v1/events".equals(httpExchange.getRequestURI().getPath())) {
            sendResponse(404, null, httpExchange);
            return;
        }
        if (!"GET".equals(httpExchange.getRequestMethod())) {
            sendResponse(405, null, httpExchange);
            return;
        }
        // Once the 200 headers are out, an error response can no longer be sent; we can only close.
        var headersSent = false;
        try {
            final var queryString = httpExchange.getRequestURI().getRawQuery();
            final var query = queryString == null ? Map.<String, String>of() : Util.getQueryMap(queryString);
            final var managers = getManagerFromQuery(query);
            if (managers == null) {
                sendResponse(400, null, httpExchange);
                return;
            }
            httpExchange.getResponseHeaders().add("Content-Type", "text/event-stream");
            httpExchange.sendResponseHeaders(200, 0);
            headersSent = true;
            final var sender = new ServerSentEventSender(httpExchange.getResponseBody());
            // Flush HTTP response headers to the client immediately.
            // Without this, the JVM HttpServer buffers everything until a later write
            // in the keep-alive loop (15 s), causing clients with shorter timeouts
            // (e.g. 10 s) to abort before receiving the initial response.
            httpExchange.getResponseBody().flush();
            final var shouldStop = new AtomicBoolean(false);
            final var handlers = subscribeReceiveHandlers(managers, sender, () -> {
                shouldStop.set(true);
                synchronized (this) {
                    this.notifyAll();
                }
            });
            try {
                while (true) {
                    synchronized (this) {
                        // Check under the monitor before waiting: a stop signal raised earlier
                        // would otherwise only be noticed after a full keep-alive interval.
                        if (shouldStop.get() || shutdown.get()) {
                            break;
                        }
                        wait(KEEP_ALIVE_INTERVAL_MS);
                    }
                    // The monitor is shared by all streams, so a wake-up may be meant for another
                    // stream; in that case this just sends an early keep-alive.
                    if (shouldStop.get() || shutdown.get()) {
                        break;
                    }
                    try {
                        sender.sendKeepAlive();
                    } catch (IOException e) {
                        break;
                    }
                }
            } finally {
                handlers.forEach(this::unsubscribeReceiveHandler);
            }
        } catch (Throwable aEx) {
            if (aEx instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            logger.error("Failed to process request.", aEx);
            if (!headersSent) {
                sendResponse(500, null, httpExchange);
            }
        } finally {
            if (headersSent) {
                try {
                    httpExchange.getResponseBody().close();
                } catch (IOException ignored) {
                    // Best effort: the client may already have disconnected.
                }
            }
        }
    }

    /**
     * Handles {@code GET /api/v1/check} with an empty 200 response. Unlike the other endpoints,
     * this one does not validate the Host header; it returns no data.
     */
    private void handleCheckEndpoint(HttpExchange httpExchange) throws IOException {
        if (!"/api/v1/check".equals(httpExchange.getRequestURI().getPath())) {
            sendResponse(404, null, httpExchange);
            return;
        }
        if (!"GET".equals(httpExchange.getRequestMethod())) {
            sendResponse(405, null, httpExchange);
            return;
        }
        sendResponse(200, null, httpExchange);
    }

    /**
     * Resolves the managers an event stream should subscribe to.
     *
     * @param query decoded query parameters; only {@code account} is read (multi-account mode)
     * @return the single-account manager; or all managers when no/empty {@code account} is given;
     *         or the one matching {@code account}; {@code null} if that account is unknown
     */
    private List<Manager> getManagerFromQuery(final Map<String, String> query) {
        if (m != null) {
            return List.of(m);
        }
        if (c != null) {
            final var account = query.get("account");
            if (account == null || account.isEmpty()) {
                return c.getManagers();
            }
            final var manager = c.getManager(account);
            return manager == null ? null : List.of(manager);
        }
        throw new AssertionError("Unreachable state");
    }

    /**
     * Registers a receive handler on every manager that forwards each received message to
     * {@code sender} as a {@code receive} event. If forwarding fails with an I/O error,
     * {@code unsubscribe} is run. If registration fails partway, the handlers already
     * registered are removed again before the exception propagates.
     *
     * @return the (manager, handler) pairs, for later {@link #unsubscribeReceiveHandler}
     */
    private List<Pair<Manager, Manager.ReceiveMessageHandler>> subscribeReceiveHandlers(
            final List<Manager> managers,
            final ServerSentEventSender sender,
            final Runnable unsubscribe
    ) {
        final var subscribed = new ArrayList<Pair<Manager, Manager.ReceiveMessageHandler>>(managers.size());
        try {
            for (final var manager : managers) {
                final var receiveMessageHandler = new JsonReceiveMessageHandler(manager, s -> {
                    try {
                        sender.sendEvent(null, "receive", List.of(objectMapper.writeValueAsString(s)));
                    } catch (IOException e) {
                        unsubscribe.run();
                    }
                });
                manager.addReceiveHandler(receiveMessageHandler);
                subscribed.add(new Pair<>(manager, (Manager.ReceiveMessageHandler) receiveMessageHandler));
            }
        } catch (RuntimeException | Error e) {
            subscribed.forEach(this::unsubscribeReceiveHandler);
            throw e;
        }
        return List.copyOf(subscribed);
    }

    /** Removes a handler previously registered by {@link #subscribeReceiveHandlers}. */
    private void unsubscribeReceiveHandler(final Pair<Manager, Manager.ReceiveMessageHandler> pair) {
        final var manager = pair.first();
        manager.removeReceiveHandler(pair.second());
    }

    /**
     * Returns whether {@code contentType} denotes JSON. The comparison is a case-insensitive prefix
     * match, so parameters such as {@code ; charset=utf-8} are accepted.
     */
    private static boolean isJsonContentType(final String contentType) {
        return contentType != null
                && contentType.strip().toLowerCase(Locale.ROOT).startsWith("application/json");
    }

    /** Returns whether the server is bound to a wildcard address (0.0.0.0 or ::). */
    private static boolean isWildcardAddress(final InetSocketAddress address) {
        if (address == null) {
            return false;
        }
        final var addr = address.getAddress();
        if (addr != null && addr.isAnyLocalAddress()) {
            return true;
        }
        final var hostStr = address.getHostString();
        return "0.0.0.0".equals(hostStr) || "::".equals(hostStr);
    }

    /**
     * Builds the set of lower-cased host names accepted in the Host header. The set contains
     * the bound host string, the bound IP literal (if resolved), and the loopback names.
     */
    private static Set<String> buildAllowedHosts(final InetSocketAddress address) {
        final var hosts = new HashSet<String>();
        if (address != null) {
            final var host = address.getHostString();
            if (host != null && !host.isEmpty()) {
                hosts.add(host.toLowerCase(Locale.ROOT));
            }
            // Also accept the IP literal the host name resolved to: a request naming the bound
            // IP directly is addressing this server.
            final var ip = address.getAddress();
            if (ip != null) {
                hosts.add(ip.getHostAddress().toLowerCase(Locale.ROOT));
            }
        }
        hosts.add("localhost");
        hosts.add("127.0.0.1");
        hosts.add("::1");
        return Set.copyOf(hosts);
    }

    /** Returns whether {@code s} consists only of ASCII digits (vacuously true for ""). */
    private static boolean isAsciiDigits(final String s) {
        return s.chars().allMatch(ch -> ch >= '0' && ch <= '9');
    }

    /**
     * Validates the request's Host header against {@link #allowedHosts} and the bound port.
     * Validation is skipped when the server is bound to a wildcard address. The header is parsed
     * as {@code host[:port]} or {@code [ipv6][:port]}. A missing port is accepted, while a
     * present port must equal the bound port.
     */
    private boolean isHostAllowed(final HttpExchange httpExchange) {
        if (isWildcardAddress(address)) {
            return true;
        }
        final var hostHeader = httpExchange.getRequestHeaders().getFirst("Host");
        if (hostHeader == null || hostHeader.isEmpty()) {
            return false;
        }
        final String hostPart;
        final String portPart;
        if (hostHeader.startsWith("[")) {
            final var closingBracket = hostHeader.indexOf(']');
            if (closingBracket == -1) {
                return false;
            }
            hostPart = hostHeader.substring(1, closingBracket);
            final var rest = hostHeader.substring(closingBracket + 1);
            if (rest.isEmpty()) {
                portPart = null;
            } else if (rest.charAt(0) == ':') {
                portPart = rest.substring(1);
            } else {
                // Anything other than ":port" after the IPv6 literal is malformed.
                return false;
            }
        } else {
            final var colon = hostHeader.lastIndexOf(':');
            final var candidatePort = colon == -1 ? null : hostHeader.substring(colon + 1);
            if (candidatePort != null && isAsciiDigits(candidatePort)) {
                hostPart = hostHeader.substring(0, colon);
                portPart = candidatePort;
            } else {
                hostPart = hostHeader;
                portPart = null;
            }
        }
        if (!allowedHosts.contains(hostPart.toLowerCase(Locale.ROOT))) {
            return false;
        }
        if (portPart == null) {
            return true;
        }
        if (address == null) {
            // No bound port to compare against: fail closed.
            return false;
        }
        try {
            return Integer.parseInt(portPart) == address.getPort();
        } catch (NumberFormatException e) {
            // Empty or out-of-range port.
            return false;
        }
    }
}