// Mirror of https://github.com/navidrome/navidrome.git
// (synced 2026-06-02 07:01:36 +00:00).
//
// Latest commit message:
//
//   - feat(conf): add Dir type with lazy directory creation.
//     Introduces the Dir type that wraps a directory path string and defers
//     os.MkdirAll until the first call to Path() or MustPath(), using
//     sync.Once to ensure the creation happens exactly once. Implements
//     fmt.Stringer, encoding.TextMarshaler, and encoding.TextUnmarshaler for
//     config integration. Includes Ginkgo/Gomega tests covering all methods
//     and error paths.
//   - refactor(conf): replace eager dir creation with lazy Dir type.
//     Change DataFolder, CacheFolder, Plugins.Folder, and Backup.Path from
//     string to Dir. Remove all os.MkdirAll calls from Load() so directories
//     are created lazily on the first Path()/MustPath() call. Artwork folder
//     creation was already handled at the point of use in image_upload.go.
//     Add SnapshotConfig() to the conf package for safe test config
//     save/restore that avoids copying the sync.Once inside Dir fields. Fix
//     the copy-lock vet warning in nativeapi/config.go by marshalling a
//     pointer instead of a value.
//   - refactor(conf): migrate tests and db init to lazy Dir type.
//     Update all test files to use conf.NewDir() for Dir field assignments.
//     Ensure DataFolder is created lazily when the database is first opened
//     in db.Db(). Remove eager directory creation from conf.Load() tests.
//   - fix(conf): address review findings for Dir type.
//       - Use os.ModePerm for DataFolder/CacheFolder (was 0700; should match
//         the original behavior). Add NewDirWithPerm for PluginsFolder (0700).
//       - Use Path() instead of MustPath() in db.Prune() to avoid logFatal
//         from the background cron job.
//       - Panic on marshal/unmarshal errors in SnapshotConfig (test helper).
//       - Clean up redundant String()/MustPath() calls in the plugin manager.
//       - Remove dead code in dir_test.go.
//     Signed-off-by: Deluan <deluan@navidrome.org>
//   - fix(conf): add GoString to Dir for clean config dump output.
//     Implement fmt.GoStringer on Dir so pretty.Sprintf shows the path string
//     instead of internal struct fields (sync.Once, perm, err). Also add a
//     TODO comment to configtest about removing the indirection.
//   - fix(dir): improve error logging in the MustPath method.
//     Signed-off-by: Deluan <deluan@navidrome.org>
//   - refactor(tests): remove redundant tests for unwritable DataFolder and
//     CacheFolder.
//     Signed-off-by: Deluan <deluan@navidrome.org>
//   - fix(conf): address PR review feedback.
//       - Ensure Plugins.Folder always uses 0700, even when user-configured
//         (previously only the derived default got restrictive permissions).
//       - Create the LogFile parent directory before opening it, so LogFile
//         paths inside a not-yet-created DataFolder work correctly.
//
// Signed-off-by: Deluan <deluan@navidrome.org>
//
// File: 617 lines, 18 KiB, Go.
//go:build !windows
|
|
|
|
package plugins
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
|
|
"encoding/hex"
|
|
"maps"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/navidrome/navidrome/conf"
|
|
"github.com/navidrome/navidrome/conf/configtest"
|
|
"github.com/navidrome/navidrome/model"
|
|
"github.com/navidrome/navidrome/tests"
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
)
|
|
|
|
var _ = Describe("WebSocketService", Ordered, func() {
|
|
var (
|
|
manager *Manager
|
|
tmpDir string
|
|
testService *testableWebSocketService
|
|
)
|
|
|
|
BeforeAll(func() {
|
|
var err error
|
|
tmpDir, err = os.MkdirTemp("", "websocket-test-*")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
// Copy the test-websocket plugin
|
|
srcPath := filepath.Join(testdataDir, "test-websocket"+PackageExtension)
|
|
destPath := filepath.Join(tmpDir, "test-websocket"+PackageExtension)
|
|
data, err := os.ReadFile(srcPath)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
err = os.WriteFile(destPath, data, 0600)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
// Compute SHA256 for the plugin
|
|
hash := sha256.Sum256(data)
|
|
hashHex := hex.EncodeToString(hash[:])
|
|
|
|
// Setup config
|
|
DeferCleanup(configtest.SetupConfig())
|
|
conf.Server.Plugins.Enabled = true
|
|
conf.Server.Plugins.Folder = conf.NewDir(tmpDir)
|
|
conf.Server.Plugins.AutoReload = false
|
|
|
|
// Setup mock DataStore with pre-enabled plugin
|
|
mockPluginRepo := tests.CreateMockPluginRepo()
|
|
mockPluginRepo.Permitted = true
|
|
mockPluginRepo.SetData(model.Plugins{{
|
|
ID: "test-websocket",
|
|
Path: destPath,
|
|
SHA256: hashHex,
|
|
Enabled: true,
|
|
}})
|
|
dataStore := &tests.MockDataStore{MockedPlugin: mockPluginRepo}
|
|
|
|
// Create and start manager
|
|
manager = &Manager{
|
|
plugins: make(map[string]*plugin),
|
|
ds: dataStore,
|
|
subsonicRouter: http.NotFoundHandler(),
|
|
metrics: noopMetricsRecorder{},
|
|
}
|
|
err = manager.Start(GinkgoT().Context())
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
// Get WebSocket service from plugin's closers and wrap it for testing
|
|
service := findWebSocketService(manager, "test-websocket")
|
|
Expect(service).ToNot(BeNil())
|
|
testService = &testableWebSocketService{webSocketServiceImpl: service}
|
|
|
|
DeferCleanup(func() {
|
|
_ = manager.Stop()
|
|
_ = os.RemoveAll(tmpDir)
|
|
})
|
|
})
|
|
|
|
BeforeEach(func() {
|
|
// Clean up any connections from previous tests
|
|
testService.closeAllConnections()
|
|
})
|
|
|
|
Describe("Plugin Loading", func() {
|
|
It("should detect WebSocket capability", func() {
|
|
names := manager.PluginNames(string(CapabilityWebSocket))
|
|
Expect(names).To(ContainElement("test-websocket"))
|
|
})
|
|
|
|
It("should register WebSocket service for plugin", func() {
|
|
service := findWebSocketService(manager, "test-websocket")
|
|
Expect(service).ToNot(BeNil())
|
|
})
|
|
})
|
|
|
|
Describe("URL Validation", func() {
|
|
It("should reject invalid URL schemes", func() {
|
|
ctx := GinkgoT().Context()
|
|
_, err := testService.Connect(ctx, "http://example.com", nil, "test-conn")
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(err.Error()).To(ContainSubstring("invalid URL scheme"))
|
|
})
|
|
|
|
It("should reject disallowed hosts", func() {
|
|
ctx := GinkgoT().Context()
|
|
_, err := testService.Connect(ctx, "wss://evil.com/socket", nil, "test-conn")
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(err.Error()).To(ContainSubstring("not allowed"))
|
|
})
|
|
|
|
It("should allow hosts matching wildcard patterns", func() {
|
|
// test-websocket manifest allows *.example.com
|
|
// The pattern *.example.com matches any host ending with .example.com
|
|
ctx := context.Background()
|
|
allowed := testService.isHostAllowed("api.example.com")
|
|
Expect(allowed).To(BeTrue())
|
|
|
|
// Deep subdomains also match (ends with .example.com)
|
|
allowed = testService.isHostAllowed("sub.api.example.com")
|
|
Expect(allowed).To(BeTrue())
|
|
|
|
// But exact match without subdomain doesn't match *.example.com
|
|
allowed = testService.isHostAllowed("example.com")
|
|
Expect(allowed).To(BeFalse())
|
|
_ = ctx
|
|
})
|
|
|
|
It("should allow exact host matches", func() {
|
|
// test-websocket manifest allows echo.websocket.org
|
|
allowed := testService.isHostAllowed("echo.websocket.org")
|
|
Expect(allowed).To(BeTrue())
|
|
|
|
allowed = testService.isHostAllowed("other.org")
|
|
Expect(allowed).To(BeFalse())
|
|
})
|
|
|
|
It("should strip port before checking host", func() {
|
|
// Implementation strips port before matching against patterns
|
|
// test-websocket manifest has "localhost:*" which matches "localhost"
|
|
// after port stripping
|
|
// Note: The port wildcard pattern isn't actually implemented, but
|
|
// since port is stripped, "localhost:*" is compared against "localhost"
|
|
// which won't match. To make localhost work, we'd need exact "localhost"
|
|
// in the allowed hosts list.
|
|
|
|
// Testing that port is properly stripped
|
|
// The pattern "localhost:*" won't match "localhost" due to exact match
|
|
allowed := testService.isHostAllowed("localhost:8080")
|
|
Expect(allowed).To(BeFalse())
|
|
})
|
|
})
|
|
|
|
Describe("Connection Management", func() {
|
|
var wsServer *httptest.Server
|
|
var serverMessages []string
|
|
var serverMu sync.Mutex
|
|
|
|
BeforeEach(func() {
|
|
serverMessages = nil
|
|
|
|
upgrader := websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool { return true },
|
|
}
|
|
wsServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// Read messages until connection closes
|
|
for {
|
|
_, msg, err := conn.ReadMessage()
|
|
if err != nil {
|
|
break
|
|
}
|
|
serverMu.Lock()
|
|
serverMessages = append(serverMessages, string(msg))
|
|
serverMu.Unlock()
|
|
}
|
|
}))
|
|
|
|
// Add the server's host to allowed hosts for testing
|
|
// Since the implementation strips port before matching, we need to add
|
|
// the host without port
|
|
serverURL := strings.TrimPrefix(wsServer.URL, "http://")
|
|
hostOnly := serverURL
|
|
if idx := strings.LastIndex(serverURL, ":"); idx != -1 {
|
|
hostOnly = serverURL[:idx]
|
|
}
|
|
testService.requiredHosts = append(testService.requiredHosts, hostOnly)
|
|
})
|
|
|
|
AfterEach(func() {
|
|
testService.closeAllConnections()
|
|
if wsServer != nil {
|
|
wsServer.Close()
|
|
}
|
|
})
|
|
|
|
It("should connect to WebSocket server", func() {
|
|
ctx := GinkgoT().Context()
|
|
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
|
|
connID, err := testService.Connect(ctx, wsURL, nil, "test-conn")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(connID).To(Equal("test-conn"))
|
|
Expect(testService.getConnectionCount()).To(Equal(1))
|
|
})
|
|
|
|
It("should generate connection ID when not provided", func() {
|
|
ctx := GinkgoT().Context()
|
|
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
|
|
connID, err := testService.Connect(ctx, wsURL, nil, "")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(connID).ToNot(BeEmpty())
|
|
})
|
|
|
|
It("should reject duplicate connection IDs", func() {
|
|
ctx := GinkgoT().Context()
|
|
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
|
|
_, err := testService.Connect(ctx, wsURL, nil, "dup-conn")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
_, err = testService.Connect(ctx, wsURL, nil, "dup-conn")
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(err.Error()).To(ContainSubstring("already exists"))
|
|
})
|
|
|
|
It("should send text messages", func() {
|
|
ctx := GinkgoT().Context()
|
|
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
|
|
connID, err := testService.Connect(ctx, wsURL, nil, "send-text-conn")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
err = testService.SendText(ctx, connID, "hello world")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
// Give server time to receive the message
|
|
Eventually(func() []string {
|
|
serverMu.Lock()
|
|
defer serverMu.Unlock()
|
|
return serverMessages
|
|
}).Should(ContainElement("hello world"))
|
|
})
|
|
|
|
It("should send binary messages", func() {
|
|
ctx := GinkgoT().Context()
|
|
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
|
|
connID, err := testService.Connect(ctx, wsURL, nil, "send-binary-conn")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
binaryData := []byte{0x00, 0x01, 0x02, 0x03}
|
|
err = testService.SendBinary(ctx, connID, binaryData)
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
// Give server time to receive the message
|
|
Eventually(func() []string {
|
|
serverMu.Lock()
|
|
defer serverMu.Unlock()
|
|
return serverMessages
|
|
}).Should(ContainElement(string(binaryData)))
|
|
})
|
|
|
|
It("should close connections", func() {
|
|
ctx := GinkgoT().Context()
|
|
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
|
|
connID, err := testService.Connect(ctx, wsURL, nil, "close-conn")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(testService.getConnectionCount()).To(Equal(1))
|
|
|
|
err = testService.CloseConnection(ctx, connID, 1000, "normal close")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(testService.getConnectionCount()).To(Equal(0))
|
|
})
|
|
|
|
It("should return error for non-existent connection", func() {
|
|
ctx := GinkgoT().Context()
|
|
err := testService.SendText(ctx, "non-existent", "message")
|
|
Expect(err).To(HaveOccurred())
|
|
Expect(err.Error()).To(ContainSubstring("not found"))
|
|
})
|
|
})
|
|
|
|
Describe("Plugin Callbacks", func() {
|
|
var wsServer *httptest.Server
|
|
var serverConn *websocket.Conn
|
|
var serverMessages []string
|
|
var serverBinaryMessages [][]byte
|
|
var serverMu sync.Mutex
|
|
|
|
BeforeEach(func() {
|
|
serverConn = nil
|
|
serverMessages = nil
|
|
serverBinaryMessages = nil
|
|
|
|
upgrader := websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool { return true },
|
|
}
|
|
wsServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
serverMu.Lock()
|
|
serverConn = conn
|
|
serverMu.Unlock()
|
|
|
|
// Read and store messages
|
|
for {
|
|
msgType, msg, err := conn.ReadMessage()
|
|
if err != nil {
|
|
break
|
|
}
|
|
serverMu.Lock()
|
|
if msgType == websocket.BinaryMessage {
|
|
serverBinaryMessages = append(serverBinaryMessages, msg)
|
|
} else {
|
|
serverMessages = append(serverMessages, string(msg))
|
|
}
|
|
serverMu.Unlock()
|
|
}
|
|
}))
|
|
|
|
serverURL := strings.TrimPrefix(wsServer.URL, "http://")
|
|
hostOnly := serverURL
|
|
if idx := strings.LastIndex(serverURL, ":"); idx != -1 {
|
|
hostOnly = serverURL[:idx]
|
|
}
|
|
testService.requiredHosts = append(testService.requiredHosts, hostOnly)
|
|
})
|
|
|
|
AfterEach(func() {
|
|
testService.closeAllConnections()
|
|
if wsServer != nil {
|
|
wsServer.Close()
|
|
}
|
|
})
|
|
|
|
It("should invoke OnBinaryMessage callback when receiving binary", func() {
|
|
ctx := GinkgoT().Context()
|
|
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
|
|
_, err := testService.Connect(ctx, wsURL, nil, "binary-cb-conn")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
// Wait for server to have the connection
|
|
Eventually(func() *websocket.Conn {
|
|
serverMu.Lock()
|
|
defer serverMu.Unlock()
|
|
return serverConn
|
|
}).ShouldNot(BeNil())
|
|
|
|
// Send binary message from server to plugin
|
|
binaryData := []byte{0xDE, 0xAD, 0xBE, 0xEF}
|
|
serverMu.Lock()
|
|
err = serverConn.WriteMessage(websocket.BinaryMessage, binaryData)
|
|
serverMu.Unlock()
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
// Plugin echoes binary data back as a binary message
|
|
Eventually(func() [][]byte {
|
|
serverMu.Lock()
|
|
defer serverMu.Unlock()
|
|
return serverBinaryMessages
|
|
}).Should(ContainElement(binaryData))
|
|
})
|
|
|
|
It("should invoke OnClose callback when server closes connection", func() {
|
|
ctx := GinkgoT().Context()
|
|
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
|
|
_, err := testService.Connect(ctx, wsURL, nil, "close-cb-conn")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
// Wait for server to have the connection
|
|
Eventually(func() *websocket.Conn {
|
|
serverMu.Lock()
|
|
defer serverMu.Unlock()
|
|
return serverConn
|
|
}).ShouldNot(BeNil())
|
|
|
|
// Close from server side
|
|
serverMu.Lock()
|
|
_ = serverConn.WriteMessage(websocket.CloseMessage,
|
|
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "goodbye"))
|
|
serverConn.Close()
|
|
serverMu.Unlock()
|
|
|
|
// Connection should be removed after close callback
|
|
Eventually(func() int {
|
|
return testService.getConnectionCount()
|
|
}).Should(Equal(0))
|
|
})
|
|
})
|
|
|
|
Describe("Plugin Host Function Calls", func() {
|
|
var wsServer *httptest.Server
|
|
var serverConn *websocket.Conn
|
|
var serverMessages []string
|
|
var serverMu sync.Mutex
|
|
|
|
BeforeEach(func() {
|
|
serverMessages = nil
|
|
serverConn = nil
|
|
|
|
upgrader := websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool { return true },
|
|
}
|
|
wsServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
serverMu.Lock()
|
|
serverConn = conn
|
|
serverMu.Unlock()
|
|
|
|
// Read and store messages
|
|
for {
|
|
_, msg, err := conn.ReadMessage()
|
|
if err != nil {
|
|
break
|
|
}
|
|
serverMu.Lock()
|
|
serverMessages = append(serverMessages, string(msg))
|
|
serverMu.Unlock()
|
|
}
|
|
}))
|
|
|
|
serverURL := strings.TrimPrefix(wsServer.URL, "http://")
|
|
hostOnly := serverURL
|
|
if idx := strings.LastIndex(serverURL, ":"); idx != -1 {
|
|
hostOnly = serverURL[:idx]
|
|
}
|
|
testService.requiredHosts = append(testService.requiredHosts, hostOnly)
|
|
})
|
|
|
|
AfterEach(func() {
|
|
testService.closeAllConnections()
|
|
if wsServer != nil {
|
|
wsServer.Close()
|
|
}
|
|
})
|
|
|
|
It("should allow plugin to send messages via host function", func() {
|
|
ctx := GinkgoT().Context()
|
|
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
|
|
_, err := testService.Connect(ctx, wsURL, nil, "host-send-conn")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
// Wait for server to have the connection
|
|
Eventually(func() *websocket.Conn {
|
|
serverMu.Lock()
|
|
defer serverMu.Unlock()
|
|
return serverConn
|
|
}).ShouldNot(BeNil())
|
|
|
|
// Server sends "echo" message to trigger plugin to echo back
|
|
serverMu.Lock()
|
|
err = serverConn.WriteMessage(websocket.TextMessage, []byte("echo"))
|
|
serverMu.Unlock()
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
// Plugin should have echoed back via host function
|
|
Eventually(func() []string {
|
|
serverMu.Lock()
|
|
defer serverMu.Unlock()
|
|
return serverMessages
|
|
}).Should(ContainElement("echo:echo"))
|
|
})
|
|
|
|
It("should allow plugin to close connection via host function", func() {
|
|
ctx := GinkgoT().Context()
|
|
wsURL := "ws://" + strings.TrimPrefix(wsServer.URL, "http://")
|
|
_, err := testService.Connect(ctx, wsURL, nil, "host-close-conn")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(testService.getConnectionCount()).To(Equal(1))
|
|
|
|
// Wait for server to have the connection
|
|
Eventually(func() *websocket.Conn {
|
|
serverMu.Lock()
|
|
defer serverMu.Unlock()
|
|
return serverConn
|
|
}).ShouldNot(BeNil())
|
|
|
|
// Server sends "close" message to trigger plugin to close connection
|
|
serverMu.Lock()
|
|
err = serverConn.WriteMessage(websocket.TextMessage, []byte("close"))
|
|
serverMu.Unlock()
|
|
Expect(err).ToNot(HaveOccurred())
|
|
|
|
// Connection should be closed by plugin
|
|
Eventually(func() int {
|
|
return testService.getConnectionCount()
|
|
}).Should(Equal(0))
|
|
})
|
|
})
|
|
|
|
Describe("Plugin Unload", func() {
|
|
It("should close all connections when plugin is unloaded", func() {
|
|
// Create a fresh server for this test
|
|
upgrader := websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool { return true },
|
|
}
|
|
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
// Keep alive
|
|
for {
|
|
_, _, err := conn.ReadMessage()
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
}))
|
|
defer wsServer.Close()
|
|
|
|
serverURL := strings.TrimPrefix(wsServer.URL, "http://")
|
|
hostOnly := serverURL
|
|
if idx := strings.LastIndex(serverURL, ":"); idx != -1 {
|
|
hostOnly = serverURL[:idx]
|
|
}
|
|
testService.requiredHosts = append(testService.requiredHosts, hostOnly)
|
|
|
|
ctx := GinkgoT().Context()
|
|
wsURL := "ws://" + serverURL
|
|
|
|
// Create multiple connections
|
|
_, err := testService.Connect(ctx, wsURL, nil, "unload-conn-1")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
_, err = testService.Connect(ctx, wsURL, nil, "unload-conn-2")
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(testService.getConnectionCount()).To(Equal(2))
|
|
|
|
// Close the service (simulates plugin unload)
|
|
err = testService.Close()
|
|
Expect(err).ToNot(HaveOccurred())
|
|
Expect(testService.getConnectionCount()).To(Equal(0))
|
|
})
|
|
})
|
|
|
|
Describe("matchHostPattern", func() {
|
|
It("should match exact hosts", func() {
|
|
Expect(matchHostPattern("example.com", "example.com")).To(BeTrue())
|
|
Expect(matchHostPattern("example.com", "other.com")).To(BeFalse())
|
|
})
|
|
|
|
It("should match wildcard patterns", func() {
|
|
Expect(matchHostPattern("*.example.com", "api.example.com")).To(BeTrue())
|
|
Expect(matchHostPattern("*.example.com", "example.com")).To(BeFalse())
|
|
Expect(matchHostPattern("*.example.com", "deep.api.example.com")).To(BeTrue())
|
|
})
|
|
|
|
It("should match bare '*' as allow-all", func() {
|
|
Expect(matchHostPattern("*", "anything.example.com")).To(BeTrue())
|
|
Expect(matchHostPattern("*", "127.0.0.1")).To(BeTrue())
|
|
Expect(matchHostPattern("*", "::1")).To(BeTrue())
|
|
})
|
|
|
|
It("should not match partial patterns", func() {
|
|
Expect(matchHostPattern("*.example.com", "example.com.evil.org")).To(BeFalse())
|
|
})
|
|
})
|
|
})
|
|
|
|
// testableWebSocketService wraps webSocketServiceImpl with test helpers.
|
|
type testableWebSocketService struct {
|
|
*webSocketServiceImpl
|
|
}
|
|
|
|
func (t *testableWebSocketService) getConnectionCount() int {
|
|
t.mu.RLock()
|
|
defer t.mu.RUnlock()
|
|
return len(t.connections)
|
|
}
|
|
|
|
func (t *testableWebSocketService) closeAllConnections() {
|
|
t.mu.Lock()
|
|
conns := make(map[string]*wsConnection, len(t.connections))
|
|
maps.Copy(conns, t.connections)
|
|
t.connections = make(map[string]*wsConnection)
|
|
t.mu.Unlock()
|
|
|
|
for _, conn := range conns {
|
|
conn.closeMu.Lock()
|
|
conn.isClosed = true
|
|
conn.closeMu.Unlock()
|
|
_ = conn.conn.Close()
|
|
close(conn.done)
|
|
}
|
|
}
|
|
|
|
// findWebSocketService finds the WebSocket service from a plugin's closers.
|
|
func findWebSocketService(m *Manager, pluginName string) *webSocketServiceImpl {
|
|
m.mu.RLock()
|
|
instance, ok := m.plugins[pluginName]
|
|
m.mu.RUnlock()
|
|
if !ok {
|
|
return nil
|
|
}
|
|
for _, closer := range instance.closers {
|
|
if svc, ok := closer.(*webSocketServiceImpl); ok {
|
|
return svc
|
|
}
|
|
}
|
|
return nil
|
|
}
|