diff --git a/plugins/host_websocket.go b/plugins/host_websocket.go index 64bdee484..7f946c73d 100644 --- a/plugins/host_websocket.go +++ b/plugins/host_websocket.go @@ -22,6 +22,9 @@ import ( // Detected when the plugin exports any of the WebSocket callback functions. const CapabilityWebSocket Capability = "WebSocket" +// webSocketCallbackTimeout is the maximum duration allowed for a WebSocket callback. +const webSocketCallbackTimeout = 30 * time.Second + // WebSocket callback function names const ( FuncWebSocketOnTextMessage = "nd_websocket_on_text_message" @@ -126,8 +129,11 @@ func (s *webSocketServiceImpl) Connect(ctx context.Context, urlStr string, heade s.connections[connectionID] = wsConn s.mu.Unlock() - // Start read goroutine - go s.readLoop(ctx, connectionID, wsConn) + // Start read goroutine with manager's context. + // We use manager.ctx instead of the caller's ctx because the readLoop must + // outlive the Connect() call. The manager's context is cancelled during + // application shutdown, ensuring graceful cleanup. + go s.readLoop(s.manager.ctx, connectionID, wsConn) log.Debug(ctx, "WebSocket connected", "plugin", s.pluginName, "connectionID", connectionID, "url", urlStr) return connectionID, nil @@ -329,8 +335,12 @@ func (s *webSocketServiceImpl) invokeOnTextMessage(ctx context.Context, connecti Message: message, } + // Create a timeout context for this callback invocation + callbackCtx, cancel := context.WithTimeout(ctx, webSocketCallbackTimeout) + defer cancel() + start := time.Now() - err := callPluginFunctionNoOutput(ctx, instance, FuncWebSocketOnTextMessage, input) + err := callPluginFunctionNoOutput(callbackCtx, instance, FuncWebSocketOnTextMessage, input) if err != nil { // Don't log error if function simply doesn't exist (optional callback) if !errors.Is(errFunctionNotFound, err) { @@ -350,8 +360,12 @@ func (s *webSocketServiceImpl) invokeOnBinaryMessage(ctx context.Context, connec Data: base64.StdEncoding.EncodeToString(data), } + // Create a timeout context for this callback invocation + callbackCtx, cancel := context.WithTimeout(ctx, webSocketCallbackTimeout) + defer cancel() + start := time.Now() - err := callPluginFunctionNoOutput(ctx, instance, FuncWebSocketOnBinaryMessage, input) + err := callPluginFunctionNoOutput(callbackCtx, instance, FuncWebSocketOnBinaryMessage, input) if err != nil { // Don't log error if function simply doesn't exist (optional callback) if !errors.Is(errFunctionNotFound, err) { @@ -371,8 +385,12 @@ func (s *webSocketServiceImpl) invokeOnError(ctx context.Context, connectionID, Error: errorMsg, } + // Create a timeout context for this callback invocation + callbackCtx, cancel := context.WithTimeout(ctx, webSocketCallbackTimeout) + defer cancel() + start := time.Now() - err := callPluginFunctionNoOutput(ctx, instance, FuncWebSocketOnError, input) + err := callPluginFunctionNoOutput(callbackCtx, instance, FuncWebSocketOnError, input) if err != nil { // Don't log error if function simply doesn't exist (optional callback) if !errors.Is(errFunctionNotFound, err) { @@ -393,8 +411,12 @@ func (s *webSocketServiceImpl) invokeOnClose(ctx context.Context, connectionID s Reason: reason, } + // Create a timeout context for this callback invocation + callbackCtx, cancel := context.WithTimeout(ctx, webSocketCallbackTimeout) + defer cancel() + start := time.Now() - err := callPluginFunctionNoOutput(ctx, instance, FuncWebSocketOnClose, input) + err := callPluginFunctionNoOutput(callbackCtx, instance, FuncWebSocketOnClose, input) if err != nil { // Don't log error if function simply doesn't exist (optional callback) if !errors.Is(errFunctionNotFound, err) {