navidrome/plugins/host_websocket_test.go
Deluan Quintão 8f0b4930ff
refactor(conf): replace eager dir creation with lazy Dir type (#5495)
* feat(conf): add Dir type with lazy directory creation

Introduces the Dir type that wraps a directory path string and defers
os.MkdirAll until the first call to Path() or MustPath(), using sync.Once
to ensure the creation happens exactly once. Implements fmt.Stringer,
encoding.TextMarshaler, and encoding.TextUnmarshaler for config integration.
Includes Ginkgo/Gomega tests covering all methods and error paths.

* refactor(conf): replace eager dir creation with lazy Dir type

Change DataFolder, CacheFolder, Plugins.Folder, and Backup.Path from
string to Dir. Remove all os.MkdirAll calls from Load() so directories
are created lazily on first Path()/MustPath() call. Artwork folder
creation was already handled at point-of-use in image_upload.go.

Add SnapshotConfig() to conf package for safe test config save/restore
that avoids copying sync.Once inside Dir fields. Fix copy-lock vet
warning in nativeapi/config.go by marshalling pointer instead of value.

* refactor(conf): migrate tests and db init to lazy Dir type

Update all test files to use conf.NewDir() for Dir field assignments.
Ensure DataFolder is created lazily when the database is first opened
in db.Db(). Remove eager directory creation from conf.Load() tests.

* fix(conf): address review findings for Dir type

- Use os.ModePerm for DataFolder/CacheFolder (was 0700, should match
  original behavior). Add NewDirWithPerm for PluginsFolder (0700).
- Use Path() instead of MustPath() in db.Prune() to avoid logFatal
  from background cron job.
- Panic on marshal/unmarshal errors in SnapshotConfig (test helper).
- Clean up redundant String()/MustPath() calls in plugin manager.
- Remove dead code in dir_test.go.

Signed-off-by: Deluan <deluan@navidrome.org>

* fix(conf): add GoString to Dir for clean config dump output

Implement fmt.GoStringer on Dir so pretty.Sprintf shows the path
string instead of internal struct fields (sync.Once, perm, err).
Also add TODO comment to configtest about removing the indirection.

* fix(dir): improve error logging in MustPath method

Signed-off-by: Deluan <deluan@navidrome.org>

* refactor(tests): remove redundant tests for unwritable DataFolder and CacheFolder

Signed-off-by: Deluan <deluan@navidrome.org>

* fix(conf): address PR review feedback

- Ensure Plugins.Folder always uses 0700, even when user-configured
  (previously only the derived default got restrictive permissions).
- Create LogFile parent directory before opening, so LogFile paths
  inside a not-yet-created DataFolder work correctly.

---------

Signed-off-by: Deluan <deluan@navidrome.org>
2026-05-13 17:44:22 -03:00

617 lines
18 KiB
Go

//go:build !windows
package plugins
import (
"context"
"crypto/sha256"
"encoding/hex"
"maps"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
"github.com/gorilla/websocket"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/conf/configtest"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/tests"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("WebSocketService", Ordered, func() {
var (
manager *Manager
tmpDir string
testService *testableWebSocketService
)
BeforeAll(func() {
var err error
tmpDir, err = os.MkdirTemp("", "websocket-test-*")
Expect(err).ToNot(HaveOccurred())
// Copy the test-websocket plugin
srcPath := filepath.Join(testdataDir, "test-websocket"+PackageExtension)
destPath := filepath.Join(tmpDir, "test-websocket"+PackageExtension)
data, err := os.ReadFile(srcPath)
Expect(err).ToNot(HaveOccurred())
err = os.WriteFile(destPath, data, 0600)
Expect(err).ToNot(HaveOccurred())
// Compute SHA256 for the plugin
hash := sha256.Sum256(data)
hashHex := hex.EncodeToString(hash[:])
// Setup config
DeferCleanup(configtest.SetupConfig())
conf.Server.Plugins.Enabled = true
conf.Server.Plugins.Folder = conf.NewDir(tmpDir)
conf.Server.Plugins.AutoReload = false
// Setup mock DataStore with pre-enabled plugin
mockPluginRepo := tests.CreateMockPluginRepo()
mockPluginRepo.Permitted = true
mockPluginRepo.SetData(model.Plugins{{
ID: "test-websocket",
Path: destPath,
SHA256: hashHex,
Enabled: true,
}})
dataStore := &tests.MockDataStore{MockedPlugin: mockPluginRepo}
// Create and start manager
manager = &Manager{
plugins: make(map[string]*plugin),
ds: dataStore,
subsonicRouter: http.NotFoundHandler(),
metrics: noopMetricsRecorder{},
}
err = manager.Start(GinkgoT().Context())
Expect(err).ToNot(HaveOccurred())
// Get WebSocket service from plugin's closers and wrap it for testing
service := findWebSocketService(manager, "test-websocket")
Expect(service).ToNot(BeNil())
testService = &testableWebSocketService{webSocketServiceImpl: service}
DeferCleanup(func() {
_ = manager.Stop()
_ = os.RemoveAll(tmpDir)
})
})
BeforeEach(func() {
// Clean up any connections from previous tests
testService.closeAllConnections()
})
Describe("Plugin Loading", func() {
It("should detect WebSocket capability", func() {
names := manager.PluginNames(string(CapabilityWebSocket))
Expect(names).To(ContainElement("test-websocket"))
})
It("should register WebSocket service for plugin", func() {
service := findWebSocketService(manager, "test-websocket")
Expect(service).ToNot(BeNil())
})
})
Describe("URL Validation", func() {
It("should reject invalid URL schemes", func() {
ctx := GinkgoT().Context()
_, err := testService.Connect(ctx, "http://example.com", nil, "test-conn")
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("invalid URL scheme"))
})
It("should reject disallowed hosts", func() {
ctx := GinkgoT().Context()
_, err := testService.Connect(ctx, "wss://evil.com/socket", nil, "test-conn")
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("not allowed"))
})
It("should allow hosts matching wildcard patterns", func() {
// test-websocket manifest allows *.example.com
// The pattern *.example.com matches any host ending with .example.com
ctx := context.Background()
allowed := testService.isHostAllowed("api.example.com")
Expect(allowed).To(BeTrue())
// Deep subdomains also match (ends with .example.com)
allowed = testService.isHostAllowed("sub.api.example.com")
Expect(allowed).To(BeTrue())
// But exact match without subdomain doesn't match *.example.com
allowed = testService.isHostAllowed("example.com")
Expect(allowed).To(BeFalse())
_ = ctx
})
It("should allow exact host matches", func() {
// test-websocket manifest allows echo.websocket.org
allowed := testService.isHostAllowed("echo.websocket.org")
Expect(allowed).To(BeTrue())
allowed = testService.isHostAllowed("other.org")
Expect(allowed).To(BeFalse())
})
It("should strip port before checking host", func() {
// Implementation strips port before matching against patterns
// test-websocket manifest has "localhost:*" which matches "localhost"
// after port stripping
// Note: The port wildcard pattern isn't actually implemented, but
// since port is stripped, "localhost:*" is compared against "localhost"
// which won't match. To make localhost work, we'd need exact "localhost"
// in the allowed hosts list.
// Testing that port is properly stripped
// The pattern "localhost:*" won't match "localhost" due to exact match
allowed := testService.isHostAllowed("localhost:8080")
Expect(allowed).To(BeFalse())
})
})
Describe("Connection Management", func() {
var wsServer *httptest.Server
var serverMessages []string
var serverMu sync.Mutex
BeforeEach(func() {
serverMessages = nil
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
wsServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
// Read messages until connection closes
for {
_, msg, err := conn.ReadMessage()
if err != nil {
break
}
serverMu.Lock()
serverMessages = append(serverMessages, string(msg))
serverMu.Unlock()
}
}))
// Add the server's host to allowed hosts for testing
// Since the implementation strips port before matching, we need to add
// the host without port
serverURL := strings.TrimPrefix(wsServer.URL, "http://")
hostOnly := serverURL
if idx := strings.LastIndex(serverURL, ":"); idx != -1 {
hostOnly = serverURL[:idx]
}
testService.requiredHosts = append(testService.requiredHosts, hostOnly)
})
AfterEach(func() {
testService.closeAllConnections()
if wsServer != nil {
wsServer.Close()
}
})
It("should connect to WebSocket server", func() {
ctx := GinkgoT().Context()
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
connID, err := testService.Connect(ctx, wsURL, nil, "test-conn")
Expect(err).ToNot(HaveOccurred())
Expect(connID).To(Equal("test-conn"))
Expect(testService.getConnectionCount()).To(Equal(1))
})
It("should generate connection ID when not provided", func() {
ctx := GinkgoT().Context()
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
connID, err := testService.Connect(ctx, wsURL, nil, "")
Expect(err).ToNot(HaveOccurred())
Expect(connID).ToNot(BeEmpty())
})
It("should reject duplicate connection IDs", func() {
ctx := GinkgoT().Context()
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
_, err := testService.Connect(ctx, wsURL, nil, "dup-conn")
Expect(err).ToNot(HaveOccurred())
_, err = testService.Connect(ctx, wsURL, nil, "dup-conn")
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("already exists"))
})
It("should send text messages", func() {
ctx := GinkgoT().Context()
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
connID, err := testService.Connect(ctx, wsURL, nil, "send-text-conn")
Expect(err).ToNot(HaveOccurred())
err = testService.SendText(ctx, connID, "hello world")
Expect(err).ToNot(HaveOccurred())
// Give server time to receive the message
Eventually(func() []string {
serverMu.Lock()
defer serverMu.Unlock()
return serverMessages
}).Should(ContainElement("hello world"))
})
It("should send binary messages", func() {
ctx := GinkgoT().Context()
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
connID, err := testService.Connect(ctx, wsURL, nil, "send-binary-conn")
Expect(err).ToNot(HaveOccurred())
binaryData := []byte{0x00, 0x01, 0x02, 0x03}
err = testService.SendBinary(ctx, connID, binaryData)
Expect(err).ToNot(HaveOccurred())
// Give server time to receive the message
Eventually(func() []string {
serverMu.Lock()
defer serverMu.Unlock()
return serverMessages
}).Should(ContainElement(string(binaryData)))
})
It("should close connections", func() {
ctx := GinkgoT().Context()
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
connID, err := testService.Connect(ctx, wsURL, nil, "close-conn")
Expect(err).ToNot(HaveOccurred())
Expect(testService.getConnectionCount()).To(Equal(1))
err = testService.CloseConnection(ctx, connID, 1000, "normal close")
Expect(err).ToNot(HaveOccurred())
Expect(testService.getConnectionCount()).To(Equal(0))
})
It("should return error for non-existent connection", func() {
ctx := GinkgoT().Context()
err := testService.SendText(ctx, "non-existent", "message")
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("not found"))
})
})
Describe("Plugin Callbacks", func() {
var wsServer *httptest.Server
var serverConn *websocket.Conn
var serverMessages []string
var serverBinaryMessages [][]byte
var serverMu sync.Mutex
BeforeEach(func() {
serverConn = nil
serverMessages = nil
serverBinaryMessages = nil
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
wsServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
serverMu.Lock()
serverConn = conn
serverMu.Unlock()
// Read and store messages
for {
msgType, msg, err := conn.ReadMessage()
if err != nil {
break
}
serverMu.Lock()
if msgType == websocket.BinaryMessage {
serverBinaryMessages = append(serverBinaryMessages, msg)
} else {
serverMessages = append(serverMessages, string(msg))
}
serverMu.Unlock()
}
}))
serverURL := strings.TrimPrefix(wsServer.URL, "http://")
hostOnly := serverURL
if idx := strings.LastIndex(serverURL, ":"); idx != -1 {
hostOnly = serverURL[:idx]
}
testService.requiredHosts = append(testService.requiredHosts, hostOnly)
})
AfterEach(func() {
testService.closeAllConnections()
if wsServer != nil {
wsServer.Close()
}
})
It("should invoke OnBinaryMessage callback when receiving binary", func() {
ctx := GinkgoT().Context()
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
_, err := testService.Connect(ctx, wsURL, nil, "binary-cb-conn")
Expect(err).ToNot(HaveOccurred())
// Wait for server to have the connection
Eventually(func() *websocket.Conn {
serverMu.Lock()
defer serverMu.Unlock()
return serverConn
}).ShouldNot(BeNil())
// Send binary message from server to plugin
binaryData := []byte{0xDE, 0xAD, 0xBE, 0xEF}
serverMu.Lock()
err = serverConn.WriteMessage(websocket.BinaryMessage, binaryData)
serverMu.Unlock()
Expect(err).ToNot(HaveOccurred())
// Plugin echoes binary data back as a binary message
Eventually(func() [][]byte {
serverMu.Lock()
defer serverMu.Unlock()
return serverBinaryMessages
}).Should(ContainElement(binaryData))
})
It("should invoke OnClose callback when server closes connection", func() {
ctx := GinkgoT().Context()
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
_, err := testService.Connect(ctx, wsURL, nil, "close-cb-conn")
Expect(err).ToNot(HaveOccurred())
// Wait for server to have the connection
Eventually(func() *websocket.Conn {
serverMu.Lock()
defer serverMu.Unlock()
return serverConn
}).ShouldNot(BeNil())
// Close from server side
serverMu.Lock()
_ = serverConn.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "goodbye"))
serverConn.Close()
serverMu.Unlock()
// Connection should be removed after close callback
Eventually(func() int {
return testService.getConnectionCount()
}).Should(Equal(0))
})
})
Describe("Plugin Host Function Calls", func() {
var wsServer *httptest.Server
var serverConn *websocket.Conn
var serverMessages []string
var serverMu sync.Mutex
BeforeEach(func() {
serverMessages = nil
serverConn = nil
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
wsServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
serverMu.Lock()
serverConn = conn
serverMu.Unlock()
// Read and store messages
for {
_, msg, err := conn.ReadMessage()
if err != nil {
break
}
serverMu.Lock()
serverMessages = append(serverMessages, string(msg))
serverMu.Unlock()
}
}))
serverURL := strings.TrimPrefix(wsServer.URL, "http://")
hostOnly := serverURL
if idx := strings.LastIndex(serverURL, ":"); idx != -1 {
hostOnly = serverURL[:idx]
}
testService.requiredHosts = append(testService.requiredHosts, hostOnly)
})
AfterEach(func() {
testService.closeAllConnections()
if wsServer != nil {
wsServer.Close()
}
})
It("should allow plugin to send messages via host function", func() {
ctx := GinkgoT().Context()
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
_, err := testService.Connect(ctx, wsURL, nil, "host-send-conn")
Expect(err).ToNot(HaveOccurred())
// Wait for server to have the connection
Eventually(func() *websocket.Conn {
serverMu.Lock()
defer serverMu.Unlock()
return serverConn
}).ShouldNot(BeNil())
// Server sends "echo" message to trigger plugin to echo back
serverMu.Lock()
err = serverConn.WriteMessage(websocket.TextMessage, []byte("echo"))
serverMu.Unlock()
Expect(err).ToNot(HaveOccurred())
// Plugin should have echoed back via host function
Eventually(func() []string {
serverMu.Lock()
defer serverMu.Unlock()
return serverMessages
}).Should(ContainElement("echo:echo"))
})
It("should allow plugin to close connection via host function", func() {
ctx := GinkgoT().Context()
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
_, err := testService.Connect(ctx, wsURL, nil, "host-close-conn")
Expect(err).ToNot(HaveOccurred())
Expect(testService.getConnectionCount()).To(Equal(1))
// Wait for server to have the connection
Eventually(func() *websocket.Conn {
serverMu.Lock()
defer serverMu.Unlock()
return serverConn
}).ShouldNot(BeNil())
// Server sends "close" message to trigger plugin to close connection
serverMu.Lock()
err = serverConn.WriteMessage(websocket.TextMessage, []byte("close"))
serverMu.Unlock()
Expect(err).ToNot(HaveOccurred())
// Connection should be closed by plugin
Eventually(func() int {
return testService.getConnectionCount()
}).Should(Equal(0))
})
})
Describe("Plugin Unload", func() {
It("should close all connections when plugin is unloaded", func() {
// Create a fresh server for this test
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
// Keep alive
for {
_, _, err := conn.ReadMessage()
if err != nil {
break
}
}
}))
defer wsServer.Close()
serverURL := strings.TrimPrefix(wsServer.URL, "http://")
hostOnly := serverURL
if idx := strings.LastIndex(serverURL, ":"); idx != -1 {
hostOnly = serverURL[:idx]
}
testService.requiredHosts = append(testService.requiredHosts, hostOnly)
ctx := GinkgoT().Context()
wsURL := "ws://" + serverURL
// Create multiple connections
_, err := testService.Connect(ctx, wsURL, nil, "unload-conn-1")
Expect(err).ToNot(HaveOccurred())
_, err = testService.Connect(ctx, wsURL, nil, "unload-conn-2")
Expect(err).ToNot(HaveOccurred())
Expect(testService.getConnectionCount()).To(Equal(2))
// Close the service (simulates plugin unload)
err = testService.Close()
Expect(err).ToNot(HaveOccurred())
Expect(testService.getConnectionCount()).To(Equal(0))
})
})
Describe("matchHostPattern", func() {
It("should match exact hosts", func() {
Expect(matchHostPattern("example.com", "example.com")).To(BeTrue())
Expect(matchHostPattern("example.com", "other.com")).To(BeFalse())
})
It("should match wildcard patterns", func() {
Expect(matchHostPattern("*.example.com", "api.example.com")).To(BeTrue())
Expect(matchHostPattern("*.example.com", "example.com")).To(BeFalse())
Expect(matchHostPattern("*.example.com", "deep.api.example.com")).To(BeTrue())
})
It("should match bare '*' as allow-all", func() {
Expect(matchHostPattern("*", "anything.example.com")).To(BeTrue())
Expect(matchHostPattern("*", "127.0.0.1")).To(BeTrue())
Expect(matchHostPattern("*", "::1")).To(BeTrue())
})
It("should not match partial patterns", func() {
Expect(matchHostPattern("*.example.com", "example.com.evil.org")).To(BeFalse())
})
})
})
// testableWebSocketService wraps webSocketServiceImpl with test helpers.
type testableWebSocketService struct {
*webSocketServiceImpl
}
func (t *testableWebSocketService) getConnectionCount() int {
t.mu.RLock()
defer t.mu.RUnlock()
return len(t.connections)
}
func (t *testableWebSocketService) closeAllConnections() {
t.mu.Lock()
conns := make(map[string]*wsConnection, len(t.connections))
maps.Copy(conns, t.connections)
t.connections = make(map[string]*wsConnection)
t.mu.Unlock()
for _, conn := range conns {
conn.closeMu.Lock()
conn.isClosed = true
conn.closeMu.Unlock()
_ = conn.conn.Close()
close(conn.done)
}
}
// findWebSocketService finds the WebSocket service from a plugin's closers.
func findWebSocketService(m *Manager, pluginName string) *webSocketServiceImpl {
m.mu.RLock()
instance, ok := m.plugins[pluginName]
m.mu.RUnlock()
if !ok {
return nil
}
for _, closer := range instance.closers {
if svc, ok := closer.(*webSocketServiceImpl); ok {
return svc
}
}
return nil
}