From d5b47383ae2f6e6acd93457f66bb436fe79beac1 Mon Sep 17 00:00:00 2001 From: Kartik Ohri Date: Sat, 26 Jul 2025 10:26:22 +0530 Subject: [PATCH] Associate API keys with players instead of users # Conflicts: # tests/mock_user_repo.go --- .../20250415111500_add_api_key_table.go | 8 +- model/api_key.go | 10 +- persistence/api_key_repository.go | 113 ++++++++---- persistence/api_key_repository_test.go | 167 ++++++++++-------- persistence/user_repository.go | 9 +- persistence/user_repository_test.go | 2 +- server/nativeapi/native_api.go | 37 ++++ server/subsonic/middlewares_test.go | 35 +++- tests/mock_apikey_repo.go | 9 +- tests/mock_data_store.go | 5 +- tests/mock_player_repo.go | 74 ++++++++ tests/mock_user_repo.go | 8 +- 12 files changed, 347 insertions(+), 130 deletions(-) create mode 100644 tests/mock_player_repo.go diff --git a/db/migrations/20250415111500_add_api_key_table.go b/db/migrations/20250415111500_add_api_key_table.go index d75ddd19f..a931d1827 100644 --- a/db/migrations/20250415111500_add_api_key_table.go +++ b/db/migrations/20250415111500_add_api_key_table.go @@ -15,18 +15,18 @@ func upAddApiKeyTable(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, ` create table if not exists api_key ( id text not null primary key, - user_id text not null, + player_id text not null, name text not null, key text not null unique, created_at datetime not null, - foreign key (user_id) - references user(id) + foreign key (player_id) + references player(id) on delete cascade ); create index if not exists api_key_key on api_key(key); - create index if not exists api_key_user_id on api_key(user_id); + create index if not exists api_key_player_id on api_key(player_id); `) return err } diff --git a/model/api_key.go b/model/api_key.go index 6d238f422..63486c2fc 100644 --- a/model/api_key.go +++ b/model/api_key.go @@ -6,10 +6,10 @@ import ( ) type APIKey struct { - ID string `structs:"id" json:"id"` - UserID string `structs:"user_id" json:"userId"` - Name string `structs:"name" json:"name"` - Key string `structs:"key" json:"key"` + ID string `structs:"id" json:"id"` + PlayerID string `structs:"player_id" json:"playerId"` + Name string `structs:"name" json:"name"` + Key string `structs:"key" json:"key"` CreatedAt time.Time `structs:"created_at" json:"createdAt"` } @@ -21,6 +21,6 @@ type APIKeyRepository interface { CountAll(...QueryOptions) (int64, error) Get(id string) (*APIKey, error) GetAll(options ...QueryOptions) (APIKeys, error) - Put(*APIKey) error FindByKey(key string) (*APIKey, error) + RefreshKey(id string) (string, error) } diff --git a/persistence/api_key_repository.go b/persistence/api_key_repository.go index 583578d3c..be8e9d715 100644 --- a/persistence/api_key_repository.go +++ b/persistence/api_key_repository.go @@ -29,16 +29,16 @@ func (r *apiKeyRepository) userFilter() Sqlizer { if user.IsAdmin { return And{} } - return Eq{"user_id": user.ID} + return Eq{"p.user_id": user.ID} } func (r *apiKeyRepository) CountAll(options ...model.QueryOptions) (int64, error) { - sq := Select().From(r.tableName).Where(r.userFilter()) + sq := r.selectAPIKey(options...).Where(r.userFilter()) return r.count(sq, options...) } func (r *apiKeyRepository) Get(id string) (*model.APIKey, error) { - sel := r.newSelect().Columns("*").Where(And{Eq{"id": id}}) + sel := r.selectAPIKey().Where(And{Eq{"ak.id": id}}) var res model.APIKey err := r.queryOne(sel, &res) if err != nil { @@ -47,8 +47,15 @@ func (r *apiKeyRepository) Get(id string) (*model.APIKey, error) { return &res, err } +func (r *apiKeyRepository) selectAPIKey(options ...model.QueryOptions) SelectBuilder { + return r.newSelect(options...). + From("api_key ak"). + LeftJoin("player p ON ak.player_id = p.id"). + Columns("ak.*") +} + func (r *apiKeyRepository) GetAll(options ...model.QueryOptions) (model.APIKeys, error) { - sel := r.newSelect(options...).Columns("*").Where(r.userFilter()) + sel := r.selectAPIKey().Where(r.userFilter()) res := model.APIKeys{} err := r.queryAll(sel, &res) if err != nil { @@ -57,32 +64,17 @@ func (r *apiKeyRepository) GetAll(options ...model.QueryOptions) (model.APIKeys, return res, err } -func (r *apiKeyRepository) Put(ak *model.APIKey) error { - if ak.ID == "" { - ak.ID = id.NewRandom() - } - ak.CreatedAt = time.Now() - values, err := toSQLArgs(*ak) - if err != nil { - return err - } - insert := Insert(r.tableName).SetMap(values) - _, err = r.executeSQL(insert) - return err -} - func (r *apiKeyRepository) Count(options ...rest.QueryOptions) (int64, error) { return r.CountAll(r.parseRestOptions(r.ctx, options...)) } func (r *apiKeyRepository) Read(id string) (interface{}, error) { - user := loggedUser(r.ctx) apiKey, err := r.Get(id) if err != nil { return nil, err } - if !user.IsAdmin && apiKey.UserID != user.ID { - return nil, rest.ErrPermissionDenied + if err := r.VerifyPlayerAccess(apiKey.PlayerID); err != nil { + return nil, err } return apiKey, err } @@ -101,14 +93,22 @@ func (r *apiKeyRepository) NewInstance() interface{} { func (r *apiKeyRepository) Save(entity interface{}) (string, error) { ak := entity.(*model.APIKey) - user := loggedUser(r.ctx) - ak.UserID = user.ID - // prefix API keys with nav_ - ak.Key = "nav_" + id.NewRandom() - err := r.Put(ak) + if err := r.VerifyPlayerAccess(ak.PlayerID); err != nil { + return "", err + } + + if ak.ID == "" { + ak.ID = id.NewRandom() + } + ak.Key = generateAPIKey() + ak.CreatedAt = time.Now() + values, err := toSQLArgs(*ak) if err != nil { return "", err } + + insert := Insert(r.tableName).SetMap(values) + _, err = r.executeSQL(insert) return ak.ID, err } @@ -118,12 +118,11 @@ func (r *apiKeyRepository) Update(id string, entity interface{}, _ ...string) er if err != nil { return err } - user := loggedUser(r.ctx) - if !user.IsAdmin && current.UserID != user.ID { - return rest.ErrPermissionDenied + + if err := r.VerifyPlayerAccess(current.PlayerID); err != nil { + return err } - // Only allow updating name update := Update(r.tableName). Set("name", ak.Name). Where(Eq{"id": id}) @@ -132,19 +131,20 @@ func (r *apiKeyRepository) Update(id string, entity interface{}, _ ...string) er } func (r *apiKeyRepository) Delete(id string) error { - user := loggedUser(r.ctx) apiKey, err := r.Get(id) if err != nil { return err } - if !user.IsAdmin && apiKey.UserID != user.ID { - return rest.ErrPermissionDenied + + if err := r.VerifyPlayerAccess(apiKey.PlayerID); err != nil { + return err } + return r.delete(Eq{"id": id}) } func (r *apiKeyRepository) FindByKey(key string) (*model.APIKey, error) { - sel := r.newSelect().Columns("*").Where(Eq{"key": key}) + sel := r.selectAPIKey().Where(And{Eq{"ak.key": key}}) var res model.APIKey err := r.queryOne(sel, &res) if err != nil { @@ -153,6 +153,51 @@ func (r *apiKeyRepository) FindByKey(key string) (*model.APIKey, error) { return &res, err } +func (r *apiKeyRepository) RefreshKey(id string) (string, error) { + apiKey, err := r.Get(id) + if err != nil { + return "", err + } + if err := r.VerifyPlayerAccess(apiKey.PlayerID); err != nil { + return "", err + } + + newKey := generateAPIKey() + update := Update(r.tableName). + Set("key", newKey). + Where(Eq{"id": id}) + _, err = r.executeSQL(update) + + if err != nil { + return "", err + } + + return newKey, nil +} + +func (r *apiKeyRepository) VerifyPlayerAccess(playerID string) error { + if playerID == "" { + return model.ErrNotFound + } + + playerRepo := NewPlayerRepository(r.ctx, r.db) + player, err := playerRepo.Get(playerID) + if err != nil { + return err + } + + user := loggedUser(r.ctx) + if !user.IsAdmin && player.UserId != user.ID { + return rest.ErrPermissionDenied + } + + return nil +} + +func generateAPIKey() string { + return "nav_" + id.NewRandom() +} + var _ model.APIKeyRepository = (*apiKeyRepository)(nil) var _ rest.Repository = (*apiKeyRepository)(nil) var _ rest.Persistable = (*apiKeyRepository)(nil) diff --git a/persistence/api_key_repository_test.go b/persistence/api_key_repository_test.go index 44ab6f0cc..d7df1504b 100644 --- a/persistence/api_key_repository_test.go +++ b/persistence/api_key_repository_test.go @@ -8,64 +8,49 @@ import ( "github.com/navidrome/navidrome/model/request" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/pocketbase/dbx" ) var _ = Describe("APIKeyRepository", func() { var repo model.APIKeyRepository + var playerRepo model.PlayerRepository + var database *dbx.DB + + var ( + adminPlayer = model.Player{ID: "1", Name: "NavidromeUI [Firefox/Linux]", UserAgent: "Firefox/Linux", UserId: adminUser.ID, Username: adminUser.UserName, Client: "NavidromeUI", IP: "127.0.0.1", ReportRealPath: true, ScrobbleEnabled: true} + regularPlayer = model.Player{ID: "3", Name: "NavidromeUI [Safari/macOS]", UserAgent: "Safari/macOS", UserId: regularUser.ID, Username: regularUser.UserName, Client: "NavidromeUI", ReportRealPath: true, ScrobbleEnabled: false} + + players = model.Players{adminPlayer, regularPlayer} + ) BeforeEach(func() { ctx := log.NewContext(context.TODO()) - ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true}) - repo = NewAPIKeyRepository(ctx, GetDBXBuilder()) - }) + ctx = request.WithUser(ctx, adminUser) + database = GetDBXBuilder() - Describe("Put", func() { - It("sets an ID if it is not set", func() { - apiKey := &model.APIKey{ - UserID: "userid", - Name: "Test API Key", - Key: "test-key", - } - - err := repo.Put(apiKey) - - Expect(err).ToNot(HaveOccurred()) - Expect(apiKey.ID).ToNot(BeEmpty()) - Expect(apiKey.CreatedAt).ToNot(BeZero()) - }) - - It("keeps existing values", func() { - apiKey := &model.APIKey{ - ID: "existing-id", - UserID: "userid", - Name: "Test API Key 2", - Key: "test-key-2", - } - - err := repo.Put(apiKey) - - Expect(err).ToNot(HaveOccurred()) - Expect(apiKey.ID).To(Equal("existing-id")) - Expect(apiKey.CreatedAt).ToNot(BeZero()) - }) + playerRepo = NewPlayerRepository(ctx, database) + for idx := range players { + err := playerRepo.Put(&players[idx]) + Expect(err).To(BeNil()) + } + repo = NewAPIKeyRepository(ctx, database) }) Describe("FindByKey", func() { It("returns the API key with matching key", func() { apiKey := &model.APIKey{ - UserID: "userid", - Name: "Unique API Key", - Key: "unique-test-key", + PlayerID: adminPlayer.ID, + Name: "Unique API Key", } - - err := repo.Put(apiKey) + apiKeyId, err := repo.Save(apiKey) + Expect(err).ToNot(HaveOccurred()) + apiKey, err = repo.Get(apiKeyId) Expect(err).ToNot(HaveOccurred()) - result, err := repo.FindByKey("unique-test-key") - + result, err := repo.FindByKey(apiKey.Key) Expect(err).ToNot(HaveOccurred()) Expect(result.ID).To(Equal(apiKey.ID)) - Expect(result.Key).To(Equal("unique-test-key")) + Expect(result.Key).To(Equal(apiKey.Key)) }) It("returns error when key not found", func() { @@ -77,7 +62,8 @@ var _ = Describe("APIKeyRepository", func() { Describe("Save", func() { It("creates a new API key with a generated key", func() { apiKey := &model.APIKey{ - Name: "Test API Key Save", + Name: "Test API Key Save", + PlayerID: adminPlayer.ID, } id, err := repo.Save(apiKey) @@ -85,25 +71,24 @@ var _ = Describe("APIKeyRepository", func() { Expect(err).ToNot(HaveOccurred()) Expect(id).ToNot(BeEmpty()) Expect(apiKey.Key).To(HavePrefix("nav_")) - Expect(apiKey.UserID).To(Equal("userid")) + Expect(apiKey.PlayerID).To(Equal(adminPlayer.ID)) + Expect(apiKey.Name).To(Equal("Test API Key Save")) }) }) Describe("Update", func() { It("only updates the name field", func() { apiKey := &model.APIKey{ - UserID: "userid", - Name: "Original Name", - Key: "test-key-for-update", + PlayerID: adminPlayer.ID, + Name: "Original Name", } - err := repo.Put(apiKey) + _, err := repo.Save(apiKey) Expect(err).ToNot(HaveOccurred()) updateKey := &model.APIKey{ - Name: "Updated Name", - Key: "should-not-change", - UserID: "2222", + Name: "Updated Name", + PlayerID: regularPlayer.ID, } err = repo.Update(apiKey.ID, updateKey) @@ -112,8 +97,8 @@ var _ = Describe("APIKeyRepository", func() { result, err := repo.Get(apiKey.ID) Expect(err).ToNot(HaveOccurred()) Expect(result.Name).To(Equal("Updated Name")) - Expect(result.Key).To(Equal("test-key-for-update")) - Expect(result.UserID).To(Equal("userid")) + Expect(result.Key).To(Equal(apiKey.Key)) + Expect(result.PlayerID).To(Equal(adminPlayer.ID)) }) It("returns error when attempting to update non-existent key", func() { @@ -125,12 +110,11 @@ var _ = Describe("APIKeyRepository", func() { Describe("Delete", func() { It("deletes an existing API key", func() { apiKey := &model.APIKey{ - UserID: "userid", - Name: "API Key to Delete", - Key: "key-to-delete", + PlayerID: adminPlayer.ID, + Name: "API Key to Delete", } - err := repo.Put(apiKey) + _, err := repo.Save(apiKey) Expect(err).ToNot(HaveOccurred()) err = repo.Delete(apiKey.ID) @@ -141,6 +125,53 @@ var _ = Describe("APIKeyRepository", func() { }) }) + Describe("RefreshKey", func() { + It("generates a new key for an existing API key", func() { + apiKey := &model.APIKey{ + PlayerID: adminPlayer.ID, + Name: "Test Refresh", + } + _, err := repo.Save(apiKey) + Expect(err).ToNot(HaveOccurred()) + + originalKey := apiKey.Key + + newKey, err := repo.RefreshKey(apiKey.ID) + + Expect(err).ToNot(HaveOccurred()) + Expect(newKey).ToNot(BeEmpty()) + Expect(newKey).ToNot(Equal(originalKey)) + Expect(newKey).To(HavePrefix("nav_")) + + refreshed, err := repo.Get(apiKey.ID) + Expect(err).ToNot(HaveOccurred()) + Expect(refreshed.Key).To(Equal(newKey)) + }) + + It("returns an error for non-existent API key", func() { + _, err := repo.RefreshKey("non-existent-id") + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError(rest.ErrNotFound)) + }) + + It("enforces user permissions", func() { + apiKey := &model.APIKey{ + PlayerID: adminPlayer.ID, + Name: "Test Permission", + } + _, err := repo.Save(apiKey) + Expect(err).ToNot(HaveOccurred()) + + nonAdminCtx := log.NewContext(context.TODO()) + nonAdminCtx = request.WithUser(nonAdminCtx, regularUser) + nonAdminRepo := NewAPIKeyRepository(nonAdminCtx, database) + + _, err = nonAdminRepo.RefreshKey(apiKey.ID) + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError(rest.ErrPermissionDenied)) + }) + }) + Describe("User permissions", func() { var nonAdminCtx context.Context var nonAdminRepo model.APIKeyRepository @@ -148,8 +179,8 @@ var _ = Describe("APIKeyRepository", func() { BeforeEach(func() { nonAdminCtx = log.NewContext(context.TODO()) - nonAdminCtx = context.WithValue(nonAdminCtx, "user", model.User{ID: "2222", UserName: "user", IsAdmin: false}) - nonAdminRepo = NewAPIKeyRepository(nonAdminCtx, GetDBXBuilder()) + nonAdminCtx = request.WithUser(nonAdminCtx, regularUser) + nonAdminRepo = NewAPIKeyRepository(nonAdminCtx, database) cleanupKeys := func(key string) { foundKey, err := repo.FindByKey(key) @@ -161,20 +192,18 @@ var _ = Describe("APIKeyRepository", func() { cleanupKeys("user-key") tmpAdminKey := &model.APIKey{ - UserID: "userid", - Name: "Admin's API Key", - Key: "admin-key", + PlayerID: adminPlayer.ID, + Name: "Admin's API Key", } - err := repo.Put(tmpAdminKey) + _, err := repo.Save(tmpAdminKey) Expect(err).ToNot(HaveOccurred()) adminKey = *tmpAdminKey userKey := &model.APIKey{ - UserID: "2222", - Name: "User's API Key", - Key: "user-key", + PlayerID: regularPlayer.ID, + Name: "User's API Key", } - err = repo.Put(userKey) + _, err = repo.Save(userKey) Expect(err).ToNot(HaveOccurred()) }) @@ -183,7 +212,7 @@ var _ = Describe("APIKeyRepository", func() { Expect(err).ToNot(HaveOccurred()) for _, key := range results { - Expect(key.UserID).To(Equal("2222")) + Expect(key.PlayerID).To(Equal(regularPlayer.ID)) } }) @@ -193,11 +222,11 @@ var _ = Describe("APIKeyRepository", func() { userIds := make(map[string]bool) for _, key := range results { - userIds[key.UserID] = true + userIds[key.PlayerID] = true } - Expect(userIds).To(HaveKey("userid")) - Expect(userIds).To(HaveKey("2222")) + Expect(userIds).To(HaveKey(adminPlayer.ID)) + Expect(userIds).To(HaveKey(regularPlayer.ID)) }) It("a user cannot view/delete/update another user's key", func() { diff --git a/persistence/user_repository.go b/persistence/user_repository.go index c675228d3..ee057e373 100644 --- a/persistence/user_repository.go +++ b/persistence/user_repository.go @@ -195,14 +195,17 @@ func (r *userRepository) FindByUsernameWithPassword(username string) (*model.Use func (r *userRepository) FindByAPIKey(key string) (*model.User, error) { // find the API key in the database + playerRepo := NewPlayerRepository(r.ctx, r.db) apiKeyRepo := NewAPIKeyRepository(r.ctx, r.db) apiKey, err := apiKeyRepo.FindByKey(key) if err != nil { return nil, err } - - // Then get the user associated with this API key - return r.Get(apiKey.UserID) + player, err := playerRepo.Get(apiKey.PlayerID) + if err != nil { + return nil, err + } + return r.Get(player.UserId) } func (r *userRepository) UpdateLastLoginAt(id string) error { diff --git a/persistence/user_repository_test.go b/persistence/user_repository_test.go index 7c0707ecd..615c336f5 100644 --- a/persistence/user_repository_test.go +++ b/persistence/user_repository_test.go @@ -212,7 +212,7 @@ var _ = Describe("UserRepository", func() { var existingUser *model.User BeforeEach(func() { existingUser = &model.User{ID: "1", UserName: "johndoe"} - repo = tests.CreateMockUserRepo() + repo = tests.CreateMockUserRepo(nil, nil) err := repo.Put(existingUser) Expect(err).ToNot(HaveOccurred()) }) diff --git a/server/nativeapi/native_api.go b/server/nativeapi/native_api.go index c6d6e9c64..cab45fd6a 100644 --- a/server/nativeapi/native_api.go +++ b/server/nativeapi/native_api.go @@ -3,6 +3,8 @@ package nativeapi import ( "context" "encoding/json" + "fmt" + "github.com/navidrome/navidrome/utils/req" "html" "net/http" "strconv" @@ -74,6 +76,7 @@ func (n *Router) routes() http.Handler { n.addUserLibraryRoute(r) n.RX(r, "/library", n.libs.NewRepository, true) }) + n.addRefreshApiKeyRoute(r) }) return r @@ -247,3 +250,37 @@ func adminOnlyMiddleware(next http.Handler) http.Handler { next.ServeHTTP(w, r) }) } + +func (n *Router) addRefreshApiKeyRoute(r chi.Router) { + r.With(server.URLParamsMiddleware).Post("/apikey/{id}/refresh", func(w http.ResponseWriter, r *http.Request) { + p := req.Params(r) + id, err := p.String(":id") + if err != nil { + msg := fmt.Sprintf("api key id could not be parsed: %s", id) + log.Warn(msg) + http.Error(w, "not found", http.StatusNotFound) + return + } + + repo := n.ds.APIKey(r.Context()) + _, err = repo.RefreshKey(id) + if err != nil { + log.Error(r.Context(), "error refreshing api key", "id", id, err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + updatedKey, err := repo.Get(id) + if err != nil { + log.Error(r.Context(), "error retrieving refreshed api key", "id", id, err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + err = rest.RespondWithJSON(w, http.StatusOK, updatedKey) + if err != nil { + log.Error(r.Context(), "error marshaling refreshed api key", "id", id, err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + }) +} diff --git a/server/subsonic/middlewares_test.go b/server/subsonic/middlewares_test.go index e9f1c0177..ee3879c09 100644 --- a/server/subsonic/middlewares_test.go +++ b/server/subsonic/middlewares_test.go @@ -308,6 +308,8 @@ var _ = Describe("Middlewares", func() { }) When("using api key authentication", func() { + var apiKey *model.APIKey + BeforeEach(func() { DeferCleanup(configtest.SetupConfig()) @@ -318,18 +320,35 @@ var _ = Describe("Middlewares", func() { } _ = ur.Put(user) - ar := ds.APIKey(context.TODO()) - apiKey := &model.APIKey{ - ID: "api-key-id", - UserID: user.ID, - Name: "API Key", - Key: "api-key", + pr := ds.Player(context.TODO()) + player := &model.Player{ + ID: "player1", + Name: "Test Player", + UserAgent: "Test/1.0", + UserId: user.ID, + Client: "test-client", + IP: "127.0.0.1", + LastSeen: time.Now(), + TranscodingId: "", + MaxBitRate: 320, + ReportRealPath: false, + ScrobbleEnabled: true, } - _ = ar.Put(apiKey) + _ = pr.Put(player) + + ar := ds.APIKey(context.TODO()) + newApiKey := &model.APIKey{ + ID: "api-key-id", + Name: "API Key", + PlayerID: player.ID, + } + apiKeyId, _ := ar.Save(newApiKey) + newApiKey, _ = ar.Get(apiKeyId) + apiKey = newApiKey }) It("passes authentication with correct api key", func() { - r := newGetRequest("apiKey=api-key") + r := newGetRequest("apiKey=" + apiKey.Key) cp := authenticate(ds)(next) cp.ServeHTTP(w, r) diff --git a/tests/mock_apikey_repo.go b/tests/mock_apikey_repo.go index 240281a13..b99f2316c 100644 --- a/tests/mock_apikey_repo.go +++ b/tests/mock_apikey_repo.go @@ -2,6 +2,7 @@ package tests import ( "github.com/navidrome/navidrome/model" + "github.com/navidrome/navidrome/model/id" "strings" ) @@ -24,12 +25,14 @@ func (m *MockedAPIKeyRepo) CountAll(_ ...model.QueryOptions) (int64, error) { return int64(len(m.Data)), nil } -func (m *MockedAPIKeyRepo) Put(apiKey *model.APIKey) error { +func (m *MockedAPIKeyRepo) Save(entity interface{}) (string, error) { if m.Error != nil { - return m.Error + return "", m.Error } + apiKey := entity.(*model.APIKey) + apiKey.Key = "nav_" + id.NewRandom() m.Data[strings.ToLower(apiKey.Key)] = apiKey - return nil + return apiKey.ID, nil } func (m *MockedAPIKeyRepo) FindByKey(key string) (*model.APIKey, error) { diff --git a/tests/mock_data_store.go b/tests/mock_data_store.go index 93f2d8a68..5945c7faa 100644 --- a/tests/mock_data_store.go +++ b/tests/mock_data_store.go @@ -170,8 +170,9 @@ func (db *MockDataStore) User(ctx context.Context) model.UserRepository { if db.RealDS != nil { db.MockedUser = db.RealDS.User(ctx) } else { + playerRepo := db.Player(ctx).(*MockedPlayerRepo) apiKeyRepo := db.APIKey(ctx).(*MockedAPIKeyRepo) - db.MockedUser = CreateMockUserRepo(apiKeyRepo) + db.MockedUser = CreateMockUserRepo(playerRepo, apiKeyRepo) } } return db.MockedUser @@ -193,7 +194,7 @@ func (db *MockDataStore) Player(ctx context.Context) model.PlayerRepository { if db.RealDS != nil { db.MockedPlayer = db.RealDS.Player(ctx) } else { - db.MockedPlayer = struct{ model.PlayerRepository }{} + db.MockedPlayer = CreateMockPlayerRepo() } } return db.MockedPlayer diff --git a/tests/mock_player_repo.go b/tests/mock_player_repo.go new file mode 100644 index 000000000..b84cc2663 --- /dev/null +++ b/tests/mock_player_repo.go @@ -0,0 +1,74 @@ +package tests + +import ( + "encoding/base64" + "github.com/navidrome/navidrome/model" +) + +func CreateMockPlayerRepo() *MockedPlayerRepo { + return &MockedPlayerRepo{ + Data: make(map[string]*model.Player), + } +} + +type MockedPlayerRepo struct { + model.PlayerRepository + Error error + Data map[string]*model.Player +} + +func (m *MockedPlayerRepo) Get(id string) (*model.Player, error) { + if m.Error != nil { + return nil, m.Error + } + + player, exists := m.Data[id] + if !exists { + return nil, model.ErrNotFound + } + return player, nil +} + +func (m *MockedPlayerRepo) FindMatch(userId, client, userAgent string) (*model.Player, error) { + if m.Error != nil { + return nil, m.Error + } + + for _, player := range m.Data { + if player.UserId == userId && player.Client == client && player.UserAgent == userAgent { + return player, nil + } + } + return nil, model.ErrNotFound +} + +func (m *MockedPlayerRepo) Put(p *model.Player) error { + if m.Error != nil { + return m.Error + } + + if p.ID == "" { + p.ID = base64.StdEncoding.EncodeToString([]byte(p.Name + "_" + p.UserId + "_" + p.Client)) + } + m.Data[p.ID] = p + return nil +} + +func (m *MockedPlayerRepo) CountAll(_ ...model.QueryOptions) (int64, error) { + if m.Error != nil { + return 0, m.Error + } + return int64(len(m.Data)), nil +} + +func (m *MockedPlayerRepo) CountByClient(_ ...model.QueryOptions) (map[string]int64, error) { + if m.Error != nil { + return nil, m.Error + } + + result := make(map[string]int64) + for _, player := range m.Data { + result[player.Client]++ + } + return result, nil +} diff --git a/tests/mock_user_repo.go b/tests/mock_user_repo.go index 5089dd18e..bc88fc04c 100644 --- a/tests/mock_user_repo.go +++ b/tests/mock_user_repo.go @@ -30,6 +30,7 @@ type MockedUserRepo struct { Data map[string]*model.User UserLibraries map[string][]int // userID -> libraryIDs APIKeyRepo *MockedAPIKeyRepo + PlayerRepo *MockedPlayerRepo } func (u *MockedUserRepo) CountAll(_ ...model.QueryOptions) (int64, error) { @@ -142,8 +143,13 @@ func (u *MockedUserRepo) FindByAPIKey(key string) (*model.User, error) { return nil, err } + player, err := u.PlayerRepo.Get(apiKey.PlayerID) + if err != nil { + return nil, err + } + for _, usr := range u.Data { - if usr.ID == apiKey.UserID { + if usr.ID == player.UserId { return usr, nil } }