diff --git a/conf/configuration.go b/conf/configuration.go index 1c4829d82..d93024c8a 100644 --- a/conf/configuration.go +++ b/conf/configuration.go @@ -134,6 +134,7 @@ type configOptions struct { DevArtworkMaxRequests int DevArtworkThrottleBacklogLimit int DevArtworkThrottleBacklogTimeout time.Duration + DevArtworkThrottleBuffered bool DevArtistInfoTimeToLive time.Duration DevAlbumInfoTimeToLive time.Duration DevExternalScanner bool @@ -861,6 +862,7 @@ func setViperDefaults() { viper.SetDefault("devartworkmaxrequests", max(2, runtime.NumCPU()/2)) viper.SetDefault("devartworkthrottlebackloglimit", consts.RequestThrottleBacklogLimit) viper.SetDefault("devartworkthrottlebacklogtimeout", consts.RequestThrottleBacklogTimeout) + viper.SetDefault("devartworkthrottlebuffered", true) viper.SetDefault("devartistinfotimetolive", consts.ArtistInfoTimeToLive) viper.SetDefault("devalbuminfotimetolive", consts.AlbumInfoTimeToLive) viper.SetDefault("devexternalscanner", true) diff --git a/server/public/public.go b/server/public/public.go index 5e3407c19..18867e1c4 100644 --- a/server/public/public.go +++ b/server/public/public.go @@ -5,14 +5,12 @@ import ( "path" "github.com/go-chi/chi/v5" - "github.com/go-chi/chi/v5/middleware" "github.com/navidrome/navidrome/conf" "github.com/navidrome/navidrome/consts" "github.com/navidrome/navidrome/core" "github.com/navidrome/navidrome/core/artwork" "github.com/navidrome/navidrome/core/publicurl" "github.com/navidrome/navidrome/core/stream" - "github.com/navidrome/navidrome/log" "github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/server" "github.com/navidrome/navidrome/ui" @@ -43,13 +41,8 @@ func (pub *Router) routes() http.Handler { r.Group(func(r chi.Router) { r.Use(server.URLParamsMiddleware) r.Group(func(r chi.Router) { - if conf.Server.DevArtworkMaxRequests > 0 { - log.Debug("Throttling public images endpoint", "maxRequests", conf.Server.DevArtworkMaxRequests, - "backlogLimit", conf.Server.DevArtworkThrottleBacklogLimit, "backlogTimeout", - conf.Server.DevArtworkThrottleBacklogTimeout) - r.Use(middleware.ThrottleBacklog(conf.Server.DevArtworkMaxRequests, conf.Server.DevArtworkThrottleBacklogLimit, - conf.Server.DevArtworkThrottleBacklogTimeout)) - } + r.Use(server.ThrottleBacklog(conf.Server.DevArtworkMaxRequests, conf.Server.DevArtworkThrottleBacklogLimit, + conf.Server.DevArtworkThrottleBacklogTimeout)) r.HandleFunc("/img/{id}", pub.handleImages) }) if conf.Server.EnableSharing { diff --git a/server/subsonic/api.go b/server/subsonic/api.go index 0bbfcb83e..1ca364449 100644 --- a/server/subsonic/api.go +++ b/server/subsonic/api.go @@ -9,7 +9,6 @@ import ( "regexp" "github.com/go-chi/chi/v5" - "github.com/go-chi/chi/v5/middleware" "github.com/navidrome/navidrome/conf" "github.com/navidrome/navidrome/core" "github.com/navidrome/navidrome/core/artwork" @@ -190,14 +189,8 @@ func (api *Router) routes() http.Handler { hr(r, "getTranscodeStream", api.GetTranscodeStream) }) r.Group(func(r chi.Router) { - // configure request throttling - if conf.Server.DevArtworkMaxRequests > 0 { - log.Debug("Throttling Subsonic getCoverArt endpoint", "maxRequests", conf.Server.DevArtworkMaxRequests, - "backlogLimit", conf.Server.DevArtworkThrottleBacklogLimit, "backlogTimeout", - conf.Server.DevArtworkThrottleBacklogTimeout) - r.Use(middleware.ThrottleBacklog(conf.Server.DevArtworkMaxRequests, conf.Server.DevArtworkThrottleBacklogLimit, - conf.Server.DevArtworkThrottleBacklogTimeout)) - } + r.Use(server.ThrottleBacklog(conf.Server.DevArtworkMaxRequests, conf.Server.DevArtworkThrottleBacklogLimit, + conf.Server.DevArtworkThrottleBacklogTimeout)) hr(r, "getCoverArt", api.GetCoverArt) }) r.Group(func(r chi.Router) { diff --git a/server/subsonic/media_retrieval_test.go b/server/subsonic/media_retrieval_test.go index 589a609da..27d1edb84 100644 --- a/server/subsonic/media_retrieval_test.go +++ b/server/subsonic/media_retrieval_test.go @@ -78,16 +78,13 @@ var _ = Describe("MediaRetrievalController", func() { When("client disconnects (context is cancelled)", func() { It("should not call the service if cancelled before the call", func() { - // Create a request ctx, cancel := context.WithCancel(context.Background()) r := newGetRequest("id=34", "size=128", "square=true") r = r.WithContext(ctx) - cancel() // Cancel the context before the call + cancel() - // Call the GetCoverArt method _, err := router.GetCoverArt(w, r) - // Expect no error and no call to the artwork service Expect(err).ToNot(HaveOccurred()) Expect(artwork.recvId).To(Equal("")) Expect(artwork.recvSize).To(Equal(0)) @@ -96,17 +93,14 @@ var _ = Describe("MediaRetrievalController", func() { }) It("should not return data if cancelled during the call", func() { - // Create a request with a context that will be cancelled ctx, cancel := context.WithCancel(context.Background()) - defer cancel() // Ensure the context is cancelled after the test (best practices) + defer cancel() r := newGetRequest("id=34", "size=128", "square=true") r = r.WithContext(ctx) - artwork.ctxCancelFunc = cancel // Set the cancel function to simulate cancellation in the service + artwork.ctxCancelFunc = cancel - // Call the GetCoverArt method _, err := router.GetCoverArt(w, r) - // Expect no error and the service to have been called Expect(err).ToNot(HaveOccurred()) Expect(artwork.recvId).To(Equal("34")) Expect(artwork.recvSize).To(Equal(128)) @@ -344,7 +338,7 @@ func (c *fakeArtwork) GetOrPlaceholder(_ context.Context, id string, size int, s c.recvSize = size c.recvSquare = square if c.ctxCancelFunc != nil { - c.ctxCancelFunc() // Simulate context cancellation + c.ctxCancelFunc() return nil, time.Time{}, context.Canceled } return io.NopCloser(bytes.NewReader([]byte(c.data))), time.Time{}, nil @@ -363,9 +357,7 @@ func (m *mockedMediaFile) GetAll(opts ...model.QueryOptions) (model.MediaFiles, return data, nil } - // Hardcoded support for lyrics sorting result := slices.Clone(data) - // Sort by presence of lyrics, then by updated_at. Respect the order specified in opts. slices.SortFunc(result, func(a, b model.MediaFile) int { diff := cmp.Or( cmp.Compare(a.Lyrics, b.Lyrics), diff --git a/server/throttle_backlog.go b/server/throttle_backlog.go new file mode 100644 index 000000000..c3672fd1e --- /dev/null +++ b/server/throttle_backlog.go @@ -0,0 +1,150 @@ +package server + +import ( + "bytes" + "context" + "errors" + "net/http" + "sync" + "time" + + "github.com/go-chi/chi/v5/middleware" + "github.com/navidrome/navidrome/conf" + "github.com/navidrome/navidrome/log" +) + +var ( + ErrThrottleCapacityExceeded = errors.New("throttle: capacity exceeded") + ErrThrottleTimeout = errors.New("throttle: backlog timeout") +) + +type requestThrottle struct { + tokens chan struct{} + backlogTokens chan struct{} + backlogTimeout time.Duration +} + +// ThrottleBacklog creates a Chi-compatible middleware that limits concurrent +// request processing. Unlike Chi's ThrottleBacklog, it buffers the handler's +// response while holding the token, releases it, then flushes the buffer to +// the client with a write deadline. This prevents slow clients from holding +// throttle capacity. +// +// Because it buffers the entire response in memory, this middleware should only +// be used for endpoints that return small responses (e.g., artwork images). Do +// not use it for audio streaming or download endpoints. +func ThrottleBacklog(limit, backlogLimit int, backlogTimeout time.Duration) func(http.Handler) http.Handler { + if limit <= 0 { + return func(next http.Handler) http.Handler { return next } + } + if !conf.Server.DevArtworkThrottleBuffered { + return middleware.ThrottleBacklog(limit, backlogLimit, backlogTimeout) + } + t := &requestThrottle{ + tokens: make(chan struct{}, limit), + backlogTokens: make(chan struct{}, limit+backlogLimit), + backlogTimeout: backlogTimeout, + } + for range limit { + t.tokens <- struct{}{} + } + for range limit + backlogLimit { + t.backlogTokens <- struct{}{} + } + return t.handler +} + +func (t *requestThrottle) handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + release, err := t.acquire(ctx) + if err != nil { + switch { + case errors.Is(err, ErrThrottleCapacityExceeded): + log.Warn(ctx, "Request throttle capacity exceeded", "path", r.URL.Path) + case errors.Is(err, ErrThrottleTimeout): + log.Warn(ctx, "Request throttle backlog timeout", "path", r.URL.Path) + } + http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) + return + } + + buf := &bufferedResponseWriter{header: make(http.Header)} + func() { + defer release() + next.ServeHTTP(buf, r) + }() + + for k, v := range buf.header { + w.Header()[k] = v + } + if buf.code > 0 { + w.WriteHeader(buf.code) + } + if _, err := w.Write(buf.body.Bytes()); err != nil { + log.Warn(ctx, "Error writing throttled response", err) + } + }) +} + +func (t *requestThrottle) acquire(ctx context.Context) (release func(), err error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-t.backlogTokens: + default: + return nil, ErrThrottleCapacityExceeded + } + + select { + case <-t.tokens: + return t.releaseFunc(), nil + default: + } + + timer := time.NewTimer(t.backlogTimeout) + select { + case <-timer.C: + t.backlogTokens <- struct{}{} + return nil, ErrThrottleTimeout + case <-ctx.Done(): + timer.Stop() + t.backlogTokens <- struct{}{} + return nil, ctx.Err() + case <-t.tokens: + timer.Stop() + return t.releaseFunc(), nil + } +} + +func (t *requestThrottle) releaseFunc() func() { + var once sync.Once + return func() { + once.Do(func() { + t.tokens <- struct{}{} + t.backlogTokens <- struct{}{} + }) + } +} + +type bufferedResponseWriter struct { + header http.Header + body bytes.Buffer + code int +} + +func (w *bufferedResponseWriter) Header() http.Header { + return w.header +} + +func (w *bufferedResponseWriter) Write(b []byte) (int, error) { + return w.body.Write(b) +} + +func (w *bufferedResponseWriter) WriteHeader(code int) { + if w.code != 0 { + return + } + w.code = code +} diff --git a/server/throttle_backlog_test.go b/server/throttle_backlog_test.go new file mode 100644 index 000000000..eb181b4ce --- /dev/null +++ b/server/throttle_backlog_test.go @@ -0,0 +1,266 @@ +package server + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/navidrome/navidrome/conf" + "github.com/navidrome/navidrome/conf/configtest" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("ThrottleBacklog", func() { + It("is a passthrough when limit is 0", func() { + m := ThrottleBacklog(0, 10, time.Second) + r := chi.NewRouter() + r.Use(m) + r.Get("/test", func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("ok")) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + r.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusOK)) + Expect(w.Body.String()).To(Equal("ok")) + }) + + It("returns 429 when capacity is exceeded", func() { + _, secondStatus := runTwoRequests(ThrottleBacklog(1, 0, time.Second)) + Expect(secondStatus).To(Equal(http.StatusTooManyRequests)) + }) + + It("returns 429 when backlog times out", func() { + _, secondStatus := runTwoRequests(ThrottleBacklog(1, 1, 50*time.Millisecond)) + Expect(secondStatus).To(Equal(http.StatusTooManyRequests)) + }) + + It("releases capacity when the handler panics", func() { + m := ThrottleBacklog(1, 0, time.Second) + r := chi.NewRouter() + r.Use(middleware.Recoverer) + r.Use(m) + r.Get("/panic", func(w http.ResponseWriter, r *http.Request) { + panic("boom") + }) + r.Get("/test", func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("ok")) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/panic", nil) + r.ServeHTTP(w, req) + Expect(w.Code).To(Equal(http.StatusInternalServerError)) + + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/test", nil) + r.ServeHTTP(w, req) + Expect(w.Code).To(Equal(http.StatusOK)) + Expect(w.Body.String()).To(Equal("ok")) + }) + + It("preserves response headers and status code", func() { + m := ThrottleBacklog(2, 0, time.Second) + r := chi.NewRouter() + r.Use(m) + r.Get("/test", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "image/jpeg") + w.Header().Set("Cache-Control", "public") + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte("body")) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + r.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusCreated)) + Expect(w.Header().Get("Content-Type")).To(Equal("image/jpeg")) + Expect(w.Header().Get("Cache-Control")).To(Equal("public")) + Expect(w.Body.String()).To(Equal("body")) + }) + + It("uses the first response status code", func() { + m := ThrottleBacklog(2, 0, time.Second) + r := chi.NewRouter() + r.Use(m) + r.Get("/test", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) + w.WriteHeader(http.StatusAccepted) + _, _ = w.Write([]byte("body")) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + r.ServeHTTP(w, req) + + Expect(w.Code).To(Equal(http.StatusCreated)) + Expect(w.Body.String()).To(Equal("body")) + }) + + It("never exceeds the concurrency limit", func() { + const limit = 3 + const goroutines = 20 + m := ThrottleBacklog(limit, goroutines, 5*time.Second) + + var concurrent atomic.Int32 + var maxConcurrent atomic.Int32 + + r := chi.NewRouter() + r.Use(m) + r.Get("/test", func(w http.ResponseWriter, r *http.Request) { + cur := concurrent.Add(1) + for { + old := maxConcurrent.Load() + if cur <= old || maxConcurrent.CompareAndSwap(old, cur) { + break + } + } + time.Sleep(5 * time.Millisecond) + concurrent.Add(-1) + _, _ = w.Write([]byte("ok")) + }) + + var wg sync.WaitGroup + for range goroutines { + wg.Go(func() { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + r.ServeHTTP(w, req) + }) + } + + wg.Wait() + Expect(maxConcurrent.Load()).To(BeNumerically("<=", limit)) + }) + + // Regression: with only 1 token, a slow client blocking during response + // writing must NOT prevent other requests from being served. Chi's original + // ThrottleBacklog holds the token for the entire handler lifecycle including + // io.Copy, causing starvation. The buffered implementation releases it first. + Context("when a client is slow to read the response", func() { + slowClientTest := func(m func(http.Handler) http.Handler) (*chi.Mux, chan struct{}, chan struct{}) { + handlerReached := make(chan struct{}, 1) + router := chi.NewRouter() + router.Use(m) + router.Get("/test", func(w http.ResponseWriter, r *http.Request) { + select { + case handlerReached <- struct{}{}: + default: + } + _, _ = io.Copy(w, strings.NewReader("image data")) + }) + + unblocked := make(chan struct{}) + slow := newSlowTestWriter(unblocked) + + reqDone := make(chan struct{}) + go func() { + defer close(reqDone) + req, _ := http.NewRequest("GET", "/test", nil) + router.ServeHTTP(slow, req) + }() + <-handlerReached + + return router, unblocked, reqDone + } + + It("does not starve concurrent requests with buffered middleware", func() { + router, unblocked, reqDone := slowClientTest(ThrottleBacklog(1, 1, 500*time.Millisecond)) + + Eventually(func() int { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + router.ServeHTTP(w, req) + return w.Code + }, 2*time.Second, 10*time.Millisecond).Should(Equal(http.StatusOK)) + + close(unblocked) + Eventually(reqDone, 2*time.Second).Should(BeClosed()) + }) + + It("starves concurrent requests with Chi's original middleware", func() { + DeferCleanup(configtest.SetupConfig()) + conf.Server.DevArtworkThrottleBuffered = false + + router, unblocked, reqDone := slowClientTest(ThrottleBacklog(1, 1, 500*time.Millisecond)) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + router.ServeHTTP(w, req) + Expect(w.Code).To(Equal(http.StatusTooManyRequests)) + + close(unblocked) + Eventually(reqDone, 2*time.Second).Should(BeClosed()) + }) + }) +}) + +// runTwoRequests sends two concurrent requests through a throttled router. The +// first request holds the token until the second has been dispatched. +func runTwoRequests(m func(http.Handler) http.Handler) (firstStatus, secondStatus int) { + held := make(chan struct{}) + release := make(chan struct{}) + r := chi.NewRouter() + r.Use(m) + r.Get("/test", func(w http.ResponseWriter, r *http.Request) { + select { + case held <- struct{}{}: + default: + } + <-release + _, _ = w.Write([]byte("ok")) + }) + + done := make(chan int) + go func() { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + r.ServeHTTP(w, req) + done <- w.Code + }() + <-held + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + r.ServeHTTP(w, req) + secondStatus = w.Code + + close(release) + firstStatus = <-done + return firstStatus, secondStatus +} + +// slowTestWriter implements http.ResponseWriter without embedding +// httptest.ResponseRecorder. This is necessary because ResponseRecorder +// promotes io.ReaderFrom, which io.Copy prefers over Write — bypassing +// our blocking Write and defeating the slow-client simulation. +type slowTestWriter struct { + header http.Header + body bytes.Buffer + code int + unblocked chan struct{} +} + +func newSlowTestWriter(unblocked chan struct{}) *slowTestWriter { + return &slowTestWriter{header: make(http.Header), unblocked: unblocked} +} + +func (w *slowTestWriter) Header() http.Header { return w.header } + +func (w *slowTestWriter) WriteHeader(code int) { w.code = code } + +func (w *slowTestWriter) Write(p []byte) (int, error) { + <-w.unblocked + return w.body.Write(p) +}