//go:build !windows package plugins import ( "context" "crypto/sha256" "encoding/base64" "encoding/hex" "net/http" "net/http/httptest" "os" "path/filepath" "strings" "sync" "time" "github.com/gorilla/websocket" "github.com/navidrome/navidrome/conf" "github.com/navidrome/navidrome/conf/configtest" "github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/tests" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("WebSocketService", Ordered, func() { var ( manager *Manager tmpDir string testService *testableWebSocketService ) BeforeAll(func() { var err error tmpDir, err = os.MkdirTemp("", "websocket-test-*") Expect(err).ToNot(HaveOccurred()) // Copy the test-websocket plugin srcPath := filepath.Join(testdataDir, "test-websocket"+PackageExtension) destPath := filepath.Join(tmpDir, "test-websocket"+PackageExtension) data, err := os.ReadFile(srcPath) Expect(err).ToNot(HaveOccurred()) err = os.WriteFile(destPath, data, 0600) Expect(err).ToNot(HaveOccurred()) // Compute SHA256 for the plugin hash := sha256.Sum256(data) hashHex := hex.EncodeToString(hash[:]) // Setup config DeferCleanup(configtest.SetupConfig()) conf.Server.Plugins.Enabled = true conf.Server.Plugins.Folder = tmpDir conf.Server.Plugins.AutoReload = false conf.Server.CacheFolder = filepath.Join(tmpDir, "cache") // Setup mock DataStore with pre-enabled plugin mockPluginRepo := tests.CreateMockPluginRepo() mockPluginRepo.Permitted = true mockPluginRepo.SetData(model.Plugins{{ ID: "test-websocket", Path: destPath, SHA256: hashHex, Enabled: true, }}) dataStore := &tests.MockDataStore{MockedPlugin: mockPluginRepo} // Create and start manager manager = &Manager{ plugins: make(map[string]*plugin), ds: dataStore, subsonicRouter: http.NotFoundHandler(), } err = manager.Start(GinkgoT().Context()) Expect(err).ToNot(HaveOccurred()) // Get WebSocket service from plugin's closers and wrap it for testing service := findWebSocketService(manager, "test-websocket") Expect(service).ToNot(BeNil()) testService = &testableWebSocketService{webSocketServiceImpl: service} DeferCleanup(func() { _ = manager.Stop() _ = os.RemoveAll(tmpDir) }) }) BeforeEach(func() { // Clean up any connections from previous tests testService.closeAllConnections() }) Describe("Plugin Loading", func() { It("should detect WebSocket capability", func() { names := manager.PluginNames(string(CapabilityWebSocket)) Expect(names).To(ContainElement("test-websocket")) }) It("should register WebSocket service for plugin", func() { service := findWebSocketService(manager, "test-websocket") Expect(service).ToNot(BeNil()) }) }) Describe("URL Validation", func() { It("should reject invalid URL schemes", func() { ctx := GinkgoT().Context() _, err := testService.Connect(ctx, "http://example.com", nil, "test-conn") Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("invalid URL scheme")) }) It("should reject disallowed hosts", func() { ctx := GinkgoT().Context() _, err := testService.Connect(ctx, "wss://evil.com/socket", nil, "test-conn") Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("not allowed")) }) It("should allow hosts matching wildcard patterns", func() { // test-websocket manifest allows *.example.com // The pattern *.example.com matches any host ending with .example.com ctx := context.Background() allowed := testService.isHostAllowed("api.example.com") Expect(allowed).To(BeTrue()) // Deep subdomains also match (ends with .example.com) allowed = testService.isHostAllowed("sub.api.example.com") Expect(allowed).To(BeTrue()) // But exact match without subdomain doesn't match *.example.com allowed = testService.isHostAllowed("example.com") Expect(allowed).To(BeFalse()) _ = ctx }) It("should allow exact host matches", func() { // test-websocket manifest allows echo.websocket.org allowed := testService.isHostAllowed("echo.websocket.org") Expect(allowed).To(BeTrue()) allowed = testService.isHostAllowed("other.org") Expect(allowed).To(BeFalse()) }) It("should strip port before checking host", func() { // Implementation strips port before matching against patterns // test-websocket manifest has "localhost:*" which matches "localhost" // after port stripping // Note: The port wildcard pattern isn't actually implemented, but // since port is stripped, "localhost:*" is compared against "localhost" // which won't match. To make localhost work, we'd need exact "localhost" // in the allowed hosts list. // Testing that port is properly stripped // The pattern "localhost:*" won't match "localhost" due to exact match allowed := testService.isHostAllowed("localhost:8080") Expect(allowed).To(BeFalse()) }) }) Describe("Connection Management", func() { var wsServer *httptest.Server var serverMessages []string var serverMu sync.Mutex BeforeEach(func() { serverMessages = nil upgrader := websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } wsServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { return } // Read messages until connection closes for { _, msg, err := conn.ReadMessage() if err != nil { break } serverMu.Lock() serverMessages = append(serverMessages, string(msg)) serverMu.Unlock() } })) // Add the server's host to allowed hosts for testing // Since the implementation strips port before matching, we need to add // the host without port serverURL := strings.TrimPrefix(wsServer.URL, "http://") hostOnly := serverURL if idx := strings.LastIndex(serverURL, ":"); idx != -1 { hostOnly = serverURL[:idx] } testService.allowedHosts = append(testService.allowedHosts, hostOnly) }) AfterEach(func() { testService.closeAllConnections() if wsServer != nil { wsServer.Close() } }) It("should connect to WebSocket server", func() { ctx := GinkgoT().Context() wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://") connID, err := testService.Connect(ctx, wsURL, nil, "test-conn") Expect(err).ToNot(HaveOccurred()) Expect(connID).To(Equal("test-conn")) Expect(testService.getConnectionCount()).To(Equal(1)) }) It("should generate connection ID when not provided", func() { ctx := GinkgoT().Context() wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://") connID, err := testService.Connect(ctx, wsURL, nil, "") Expect(err).ToNot(HaveOccurred()) Expect(connID).ToNot(BeEmpty()) }) It("should reject duplicate connection IDs", func() { ctx := GinkgoT().Context() wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://") _, err := testService.Connect(ctx, wsURL, nil, "dup-conn") Expect(err).ToNot(HaveOccurred()) _, err = testService.Connect(ctx, wsURL, nil, "dup-conn") Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("already exists")) }) It("should send text messages", func() { ctx := GinkgoT().Context() wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://") connID, err := testService.Connect(ctx, wsURL, nil, "send-text-conn") Expect(err).ToNot(HaveOccurred()) err = testService.SendText(ctx, connID, "hello world") Expect(err).ToNot(HaveOccurred()) // Give server time to receive the message Eventually(func() []string { serverMu.Lock() defer serverMu.Unlock() return serverMessages }).Should(ContainElement("hello world")) }) It("should send binary messages", func() { ctx := GinkgoT().Context() wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://") connID, err := testService.Connect(ctx, wsURL, nil, "send-binary-conn") Expect(err).ToNot(HaveOccurred()) binaryData := []byte{0x00, 0x01, 0x02, 0x03} err = testService.SendBinary(ctx, connID, binaryData) Expect(err).ToNot(HaveOccurred()) // Give server time to receive the message Eventually(func() []string { serverMu.Lock() defer serverMu.Unlock() return serverMessages }).Should(ContainElement(string(binaryData))) }) It("should close connections", func() { ctx := GinkgoT().Context() wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://") connID, err := testService.Connect(ctx, wsURL, nil, "close-conn") Expect(err).ToNot(HaveOccurred()) Expect(testService.getConnectionCount()).To(Equal(1)) err = testService.CloseConnection(ctx, connID, 1000, "normal close") Expect(err).ToNot(HaveOccurred()) Expect(testService.getConnectionCount()).To(Equal(0)) }) It("should return error for non-existent connection", func() { ctx := GinkgoT().Context() err := testService.SendText(ctx, "non-existent", "message") Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("not found")) }) }) Describe("Plugin Callbacks", func() { var wsServer *httptest.Server var serverConn *websocket.Conn var serverMu sync.Mutex BeforeEach(func() { serverConn = nil upgrader := websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } wsServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { return } serverMu.Lock() serverConn = conn serverMu.Unlock() // Keep connection open for { _, _, err := conn.ReadMessage() if err != nil { break } } })) serverURL := strings.TrimPrefix(wsServer.URL, "http://") hostOnly := serverURL if idx := strings.LastIndex(serverURL, ":"); idx != -1 { hostOnly = serverURL[:idx] } testService.allowedHosts = append(testService.allowedHosts, hostOnly) }) AfterEach(func() { testService.closeAllConnections() if wsServer != nil { wsServer.Close() } }) It("should invoke OnTextMessage callback when receiving text", func() { ctx := GinkgoT().Context() wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://") connID, err := testService.Connect(ctx, wsURL, nil, "text-cb-conn") Expect(err).ToNot(HaveOccurred()) // Wait for server to have the connection Eventually(func() *websocket.Conn { serverMu.Lock() defer serverMu.Unlock() return serverConn }).ShouldNot(BeNil()) // Send message from server to plugin serverMu.Lock() err = serverConn.WriteMessage(websocket.TextMessage, []byte("test message")) serverMu.Unlock() Expect(err).ToNot(HaveOccurred()) // The plugin should have received the callback // We can verify by checking the plugin's stored messages via vars // For now we just verify no errors occurred time.Sleep(100 * time.Millisecond) _ = connID }) It("should invoke OnBinaryMessage callback when receiving binary", func() { ctx := GinkgoT().Context() wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://") connID, err := testService.Connect(ctx, wsURL, nil, "binary-cb-conn") Expect(err).ToNot(HaveOccurred()) // Wait for server to have the connection Eventually(func() *websocket.Conn { serverMu.Lock() defer serverMu.Unlock() return serverConn }).ShouldNot(BeNil()) // Send binary message from server to plugin binaryData := []byte{0xDE, 0xAD, 0xBE, 0xEF} serverMu.Lock() err = serverConn.WriteMessage(websocket.BinaryMessage, binaryData) serverMu.Unlock() Expect(err).ToNot(HaveOccurred()) // Give time for callback to execute time.Sleep(100 * time.Millisecond) _ = connID }) It("should invoke OnClose callback when server closes connection", func() { ctx := GinkgoT().Context() wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://") _, err := testService.Connect(ctx, wsURL, nil, "close-cb-conn") Expect(err).ToNot(HaveOccurred()) // Wait for server to have the connection Eventually(func() *websocket.Conn { serverMu.Lock() defer serverMu.Unlock() return serverConn }).ShouldNot(BeNil()) // Close from server side serverMu.Lock() _ = serverConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "goodbye")) serverConn.Close() serverMu.Unlock() // Connection should be removed after close callback Eventually(func() int { return testService.getConnectionCount() }).Should(Equal(0)) }) }) Describe("Plugin Host Function Calls", func() { var wsServer *httptest.Server var serverConn *websocket.Conn var serverMessages []string var serverMu sync.Mutex BeforeEach(func() { serverMessages = nil serverConn = nil upgrader := websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } wsServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { return } serverMu.Lock() serverConn = conn serverMu.Unlock() // Read and store messages for { _, msg, err := conn.ReadMessage() if err != nil { break } serverMu.Lock() serverMessages = append(serverMessages, string(msg)) serverMu.Unlock() } })) serverURL := strings.TrimPrefix(wsServer.URL, "http://") hostOnly := serverURL if idx := strings.LastIndex(serverURL, ":"); idx != -1 { hostOnly = serverURL[:idx] } testService.allowedHosts = append(testService.allowedHosts, hostOnly) }) AfterEach(func() { testService.closeAllConnections() if wsServer != nil { wsServer.Close() } }) It("should allow plugin to send messages via host function", func() { ctx := GinkgoT().Context() wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://") connID, err := testService.Connect(ctx, wsURL, nil, "host-send-conn") Expect(err).ToNot(HaveOccurred()) // Wait for server to have the connection Eventually(func() *websocket.Conn { serverMu.Lock() defer serverMu.Unlock() return serverConn }).ShouldNot(BeNil()) // Server sends "echo" message to trigger plugin to echo back serverMu.Lock() err = serverConn.WriteMessage(websocket.TextMessage, []byte("echo")) serverMu.Unlock() Expect(err).ToNot(HaveOccurred()) // Plugin should have echoed back via host function Eventually(func() []string { serverMu.Lock() defer serverMu.Unlock() return serverMessages }).Should(ContainElement("echo:echo")) _ = connID }) It("should allow plugin to close connection via host function", func() { ctx := GinkgoT().Context() wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://") _, err := testService.Connect(ctx, wsURL, nil, "host-close-conn") Expect(err).ToNot(HaveOccurred()) Expect(testService.getConnectionCount()).To(Equal(1)) // Wait for server to have the connection Eventually(func() *websocket.Conn { serverMu.Lock() defer serverMu.Unlock() return serverConn }).ShouldNot(BeNil()) // Server sends "close" message to trigger plugin to close connection serverMu.Lock() err = serverConn.WriteMessage(websocket.TextMessage, []byte("close")) serverMu.Unlock() Expect(err).ToNot(HaveOccurred()) // Connection should be closed by plugin Eventually(func() int { return testService.getConnectionCount() }).Should(Equal(0)) }) }) Describe("Plugin Unload", func() { It("should close all connections when plugin is unloaded", func() { // Create a fresh server for this test upgrader := websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { return } // Keep alive for { _, _, err := conn.ReadMessage() if err != nil { break } } })) defer wsServer.Close() serverURL := strings.TrimPrefix(wsServer.URL, "http://") hostOnly := serverURL if idx := strings.LastIndex(serverURL, ":"); idx != -1 { hostOnly = serverURL[:idx] } testService.allowedHosts = append(testService.allowedHosts, hostOnly) ctx := GinkgoT().Context() wsURL := "ws://" + serverURL // Create multiple connections _, err := testService.Connect(ctx, wsURL, nil, "unload-conn-1") Expect(err).ToNot(HaveOccurred()) _, err = testService.Connect(ctx, wsURL, nil, "unload-conn-2") Expect(err).ToNot(HaveOccurred()) Expect(testService.getConnectionCount()).To(Equal(2)) // Close the service (simulates plugin unload) err = testService.Close() Expect(err).ToNot(HaveOccurred()) Expect(testService.getConnectionCount()).To(Equal(0)) }) }) Describe("matchHostPattern", func() { It("should match exact hosts", func() { Expect(matchHostPattern("example.com", "example.com")).To(BeTrue()) Expect(matchHostPattern("example.com", "other.com")).To(BeFalse()) }) It("should match wildcard patterns", func() { Expect(matchHostPattern("*.example.com", "api.example.com")).To(BeTrue()) Expect(matchHostPattern("*.example.com", "example.com")).To(BeFalse()) Expect(matchHostPattern("*.example.com", "deep.api.example.com")).To(BeTrue()) }) It("should not match partial patterns", func() { Expect(matchHostPattern("*.example.com", "example.com.evil.org")).To(BeFalse()) }) }) }) // testableWebSocketService wraps webSocketServiceImpl with test helpers. type testableWebSocketService struct { *webSocketServiceImpl } func (t *testableWebSocketService) getConnectionCount() int { t.mu.RLock() defer t.mu.RUnlock() return len(t.connections) } func (t *testableWebSocketService) closeAllConnections() { t.mu.Lock() conns := make(map[string]*wsConnection, len(t.connections)) for k, v := range t.connections { conns[k] = v } t.connections = make(map[string]*wsConnection) t.mu.Unlock() for _, conn := range conns { conn.closeMu.Lock() conn.isClosed = true conn.closeMu.Unlock() _ = conn.conn.Close() close(conn.done) } } // findWebSocketService finds the WebSocket service from a plugin's closers. func findWebSocketService(m *Manager, pluginName string) *webSocketServiceImpl { m.mu.RLock() instance, ok := m.plugins[pluginName] m.mu.RUnlock() if !ok { return nil } for _, closer := range instance.closers { if svc, ok := closer.(*webSocketServiceImpl); ok { return svc } } return nil } // Ensure base64 import is used var _ = base64.StdEncoding