Associate API keys with players instead of users

# Conflicts:
#	tests/mock_user_repo.go
This commit is contained in:
Kartik Ohri 2025-07-26 10:26:22 +05:30
parent c655a91a5d
commit d5b47383ae
12 changed files with 347 additions and 130 deletions

View File

@ -15,18 +15,18 @@ func upAddApiKeyTable(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, ` _, err := tx.ExecContext(ctx, `
create table if not exists api_key ( create table if not exists api_key (
id text not null primary key, id text not null primary key,
user_id text not null, player_id text not null,
name text not null, name text not null,
key text not null unique, key text not null unique,
created_at datetime not null, created_at datetime not null,
foreign key (user_id) foreign key (player_id)
references user(id) references player(id)
on delete cascade on delete cascade
); );
create index if not exists api_key_key on api_key(key); create index if not exists api_key_key on api_key(key);
create index if not exists api_key_user_id on api_key(user_id); create index if not exists api_key_player_id on api_key(player_id);
`) `)
return err return err
} }

View File

@ -6,10 +6,10 @@ import (
) )
type APIKey struct { type APIKey struct {
ID string `structs:"id" json:"id"` ID string `structs:"id" json:"id"`
UserID string `structs:"user_id" json:"userId"` PlayerID string `structs:"player_id" json:"playerId"`
Name string `structs:"name" json:"name"` Name string `structs:"name" json:"name"`
Key string `structs:"key" json:"key"` Key string `structs:"key" json:"key"`
CreatedAt time.Time `structs:"created_at" json:"createdAt"` CreatedAt time.Time `structs:"created_at" json:"createdAt"`
} }
@ -21,6 +21,6 @@ type APIKeyRepository interface {
CountAll(...QueryOptions) (int64, error) CountAll(...QueryOptions) (int64, error)
Get(id string) (*APIKey, error) Get(id string) (*APIKey, error)
GetAll(options ...QueryOptions) (APIKeys, error) GetAll(options ...QueryOptions) (APIKeys, error)
Put(*APIKey) error
FindByKey(key string) (*APIKey, error) FindByKey(key string) (*APIKey, error)
RefreshKey(id string) (string, error)
} }

View File

@ -29,16 +29,16 @@ func (r *apiKeyRepository) userFilter() Sqlizer {
if user.IsAdmin { if user.IsAdmin {
return And{} return And{}
} }
return Eq{"user_id": user.ID} return Eq{"p.user_id": user.ID}
} }
func (r *apiKeyRepository) CountAll(options ...model.QueryOptions) (int64, error) { func (r *apiKeyRepository) CountAll(options ...model.QueryOptions) (int64, error) {
sq := Select().From(r.tableName).Where(r.userFilter()) sq := r.selectAPIKey(options...).Where(r.userFilter())
return r.count(sq, options...) return r.count(sq, options...)
} }
func (r *apiKeyRepository) Get(id string) (*model.APIKey, error) { func (r *apiKeyRepository) Get(id string) (*model.APIKey, error) {
sel := r.newSelect().Columns("*").Where(And{Eq{"id": id}}) sel := r.selectAPIKey().Where(And{Eq{"ak.id": id}})
var res model.APIKey var res model.APIKey
err := r.queryOne(sel, &res) err := r.queryOne(sel, &res)
if err != nil { if err != nil {
@ -47,8 +47,15 @@ func (r *apiKeyRepository) Get(id string) (*model.APIKey, error) {
return &res, err return &res, err
} }
func (r *apiKeyRepository) selectAPIKey(options ...model.QueryOptions) SelectBuilder {
return r.newSelect(options...).
From("api_key ak").
LeftJoin("player p ON ak.player_id = p.id").
Columns("ak.*")
}
func (r *apiKeyRepository) GetAll(options ...model.QueryOptions) (model.APIKeys, error) { func (r *apiKeyRepository) GetAll(options ...model.QueryOptions) (model.APIKeys, error) {
sel := r.newSelect(options...).Columns("*").Where(r.userFilter()) sel := r.selectAPIKey().Where(r.userFilter())
res := model.APIKeys{} res := model.APIKeys{}
err := r.queryAll(sel, &res) err := r.queryAll(sel, &res)
if err != nil { if err != nil {
@ -57,32 +64,17 @@ func (r *apiKeyRepository) GetAll(options ...model.QueryOptions) (model.APIKeys,
return res, err return res, err
} }
func (r *apiKeyRepository) Put(ak *model.APIKey) error {
if ak.ID == "" {
ak.ID = id.NewRandom()
}
ak.CreatedAt = time.Now()
values, err := toSQLArgs(*ak)
if err != nil {
return err
}
insert := Insert(r.tableName).SetMap(values)
_, err = r.executeSQL(insert)
return err
}
func (r *apiKeyRepository) Count(options ...rest.QueryOptions) (int64, error) { func (r *apiKeyRepository) Count(options ...rest.QueryOptions) (int64, error) {
return r.CountAll(r.parseRestOptions(r.ctx, options...)) return r.CountAll(r.parseRestOptions(r.ctx, options...))
} }
func (r *apiKeyRepository) Read(id string) (interface{}, error) { func (r *apiKeyRepository) Read(id string) (interface{}, error) {
user := loggedUser(r.ctx)
apiKey, err := r.Get(id) apiKey, err := r.Get(id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !user.IsAdmin && apiKey.UserID != user.ID { if err := r.VerifyPlayerAccess(apiKey.PlayerID); err != nil {
return nil, rest.ErrPermissionDenied return nil, err
} }
return apiKey, err return apiKey, err
} }
@ -101,14 +93,22 @@ func (r *apiKeyRepository) NewInstance() interface{} {
func (r *apiKeyRepository) Save(entity interface{}) (string, error) { func (r *apiKeyRepository) Save(entity interface{}) (string, error) {
ak := entity.(*model.APIKey) ak := entity.(*model.APIKey)
user := loggedUser(r.ctx) if err := r.VerifyPlayerAccess(ak.PlayerID); err != nil {
ak.UserID = user.ID return "", err
// prefix API keys with nav_ }
ak.Key = "nav_" + id.NewRandom()
err := r.Put(ak) if ak.ID == "" {
ak.ID = id.NewRandom()
}
ak.Key = generateAPIKey()
ak.CreatedAt = time.Now()
values, err := toSQLArgs(*ak)
if err != nil { if err != nil {
return "", err return "", err
} }
insert := Insert(r.tableName).SetMap(values)
_, err = r.executeSQL(insert)
return ak.ID, err return ak.ID, err
} }
@ -118,12 +118,11 @@ func (r *apiKeyRepository) Update(id string, entity interface{}, _ ...string) er
if err != nil { if err != nil {
return err return err
} }
user := loggedUser(r.ctx)
if !user.IsAdmin && current.UserID != user.ID { if err := r.VerifyPlayerAccess(current.PlayerID); err != nil {
return rest.ErrPermissionDenied return err
} }
// Only allow updating name
update := Update(r.tableName). update := Update(r.tableName).
Set("name", ak.Name). Set("name", ak.Name).
Where(Eq{"id": id}) Where(Eq{"id": id})
@ -132,19 +131,20 @@ func (r *apiKeyRepository) Update(id string, entity interface{}, _ ...string) er
} }
func (r *apiKeyRepository) Delete(id string) error { func (r *apiKeyRepository) Delete(id string) error {
user := loggedUser(r.ctx)
apiKey, err := r.Get(id) apiKey, err := r.Get(id)
if err != nil { if err != nil {
return err return err
} }
if !user.IsAdmin && apiKey.UserID != user.ID {
return rest.ErrPermissionDenied if err := r.VerifyPlayerAccess(apiKey.PlayerID); err != nil {
return err
} }
return r.delete(Eq{"id": id}) return r.delete(Eq{"id": id})
} }
func (r *apiKeyRepository) FindByKey(key string) (*model.APIKey, error) { func (r *apiKeyRepository) FindByKey(key string) (*model.APIKey, error) {
sel := r.newSelect().Columns("*").Where(Eq{"key": key}) sel := r.selectAPIKey().Where(And{Eq{"ak.key": key}})
var res model.APIKey var res model.APIKey
err := r.queryOne(sel, &res) err := r.queryOne(sel, &res)
if err != nil { if err != nil {
@ -153,6 +153,51 @@ func (r *apiKeyRepository) FindByKey(key string) (*model.APIKey, error) {
return &res, err return &res, err
} }
func (r *apiKeyRepository) RefreshKey(id string) (string, error) {
apiKey, err := r.Get(id)
if err != nil {
return "", err
}
if err := r.VerifyPlayerAccess(apiKey.PlayerID); err != nil {
return "", err
}
newKey := generateAPIKey()
update := Update(r.tableName).
Set("key", newKey).
Where(Eq{"id": id})
_, err = r.executeSQL(update)
if err != nil {
return "", err
}
return newKey, nil
}
func (r *apiKeyRepository) VerifyPlayerAccess(playerID string) error {
if playerID == "" {
return model.ErrNotFound
}
playerRepo := NewPlayerRepository(r.ctx, r.db)
player, err := playerRepo.Get(playerID)
if err != nil {
return err
}
user := loggedUser(r.ctx)
if !user.IsAdmin && player.UserId != user.ID {
return rest.ErrPermissionDenied
}
return nil
}
func generateAPIKey() string {
return "nav_" + id.NewRandom()
}
var _ model.APIKeyRepository = (*apiKeyRepository)(nil) var _ model.APIKeyRepository = (*apiKeyRepository)(nil)
var _ rest.Repository = (*apiKeyRepository)(nil) var _ rest.Repository = (*apiKeyRepository)(nil)
var _ rest.Persistable = (*apiKeyRepository)(nil) var _ rest.Persistable = (*apiKeyRepository)(nil)

View File

@ -8,64 +8,49 @@ import (
"github.com/navidrome/navidrome/model/request" "github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"github.com/pocketbase/dbx"
) )
var _ = Describe("APIKeyRepository", func() { var _ = Describe("APIKeyRepository", func() {
var repo model.APIKeyRepository var repo model.APIKeyRepository
var playerRepo model.PlayerRepository
var database *dbx.DB
var (
adminPlayer = model.Player{ID: "1", Name: "NavidromeUI [Firefox/Linux]", UserAgent: "Firefox/Linux", UserId: adminUser.ID, Username: adminUser.UserName, Client: "NavidromeUI", IP: "127.0.0.1", ReportRealPath: true, ScrobbleEnabled: true}
regularPlayer = model.Player{ID: "3", Name: "NavidromeUI [Safari/macOS]", UserAgent: "Safari/macOS", UserId: regularUser.ID, Username: regularUser.UserName, Client: "NavidromeUI", ReportRealPath: true, ScrobbleEnabled: false}
players = model.Players{adminPlayer, regularPlayer}
)
BeforeEach(func() { BeforeEach(func() {
ctx := log.NewContext(context.TODO()) ctx := log.NewContext(context.TODO())
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true}) ctx = request.WithUser(ctx, adminUser)
repo = NewAPIKeyRepository(ctx, GetDBXBuilder()) database = GetDBXBuilder()
})
Describe("Put", func() { playerRepo = NewPlayerRepository(ctx, database)
It("sets an ID if it is not set", func() { for idx := range players {
apiKey := &model.APIKey{ err := playerRepo.Put(&players[idx])
UserID: "userid", Expect(err).To(BeNil())
Name: "Test API Key", }
Key: "test-key", repo = NewAPIKeyRepository(ctx, database)
}
err := repo.Put(apiKey)
Expect(err).ToNot(HaveOccurred())
Expect(apiKey.ID).ToNot(BeEmpty())
Expect(apiKey.CreatedAt).ToNot(BeZero())
})
It("keeps existing values", func() {
apiKey := &model.APIKey{
ID: "existing-id",
UserID: "userid",
Name: "Test API Key 2",
Key: "test-key-2",
}
err := repo.Put(apiKey)
Expect(err).ToNot(HaveOccurred())
Expect(apiKey.ID).To(Equal("existing-id"))
Expect(apiKey.CreatedAt).ToNot(BeZero())
})
}) })
Describe("FindByKey", func() { Describe("FindByKey", func() {
It("returns the API key with matching key", func() { It("returns the API key with matching key", func() {
apiKey := &model.APIKey{ apiKey := &model.APIKey{
UserID: "userid", PlayerID: adminPlayer.ID,
Name: "Unique API Key", Name: "Unique API Key",
Key: "unique-test-key",
} }
apiKeyId, err := repo.Save(apiKey)
err := repo.Put(apiKey) Expect(err).ToNot(HaveOccurred())
apiKey, err = repo.Get(apiKeyId)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
result, err := repo.FindByKey("unique-test-key") result, err := repo.FindByKey(apiKey.Key)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(result.ID).To(Equal(apiKey.ID)) Expect(result.ID).To(Equal(apiKey.ID))
Expect(result.Key).To(Equal("unique-test-key")) Expect(result.Key).To(Equal(apiKey.Key))
}) })
It("returns error when key not found", func() { It("returns error when key not found", func() {
@ -77,7 +62,8 @@ var _ = Describe("APIKeyRepository", func() {
Describe("Save", func() { Describe("Save", func() {
It("creates a new API key with a generated key", func() { It("creates a new API key with a generated key", func() {
apiKey := &model.APIKey{ apiKey := &model.APIKey{
Name: "Test API Key Save", Name: "Test API Key Save",
PlayerID: adminPlayer.ID,
} }
id, err := repo.Save(apiKey) id, err := repo.Save(apiKey)
@ -85,25 +71,24 @@ var _ = Describe("APIKeyRepository", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(id).ToNot(BeEmpty()) Expect(id).ToNot(BeEmpty())
Expect(apiKey.Key).To(HavePrefix("nav_")) Expect(apiKey.Key).To(HavePrefix("nav_"))
Expect(apiKey.UserID).To(Equal("userid")) Expect(apiKey.PlayerID).To(Equal(adminPlayer.ID))
Expect(apiKey.Name).To(Equal("Test API Key Save"))
}) })
}) })
Describe("Update", func() { Describe("Update", func() {
It("only updates the name field", func() { It("only updates the name field", func() {
apiKey := &model.APIKey{ apiKey := &model.APIKey{
UserID: "userid", PlayerID: adminPlayer.ID,
Name: "Original Name", Name: "Original Name",
Key: "test-key-for-update",
} }
err := repo.Put(apiKey) _, err := repo.Save(apiKey)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
updateKey := &model.APIKey{ updateKey := &model.APIKey{
Name: "Updated Name", Name: "Updated Name",
Key: "should-not-change", PlayerID: regularPlayer.ID,
UserID: "2222",
} }
err = repo.Update(apiKey.ID, updateKey) err = repo.Update(apiKey.ID, updateKey)
@ -112,8 +97,8 @@ var _ = Describe("APIKeyRepository", func() {
result, err := repo.Get(apiKey.ID) result, err := repo.Get(apiKey.ID)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(result.Name).To(Equal("Updated Name")) Expect(result.Name).To(Equal("Updated Name"))
Expect(result.Key).To(Equal("test-key-for-update")) Expect(result.Key).To(Equal(apiKey.Key))
Expect(result.UserID).To(Equal("userid")) Expect(result.PlayerID).To(Equal(adminPlayer.ID))
}) })
It("returns error when attempting to update non-existent key", func() { It("returns error when attempting to update non-existent key", func() {
@ -125,12 +110,11 @@ var _ = Describe("APIKeyRepository", func() {
Describe("Delete", func() { Describe("Delete", func() {
It("deletes an existing API key", func() { It("deletes an existing API key", func() {
apiKey := &model.APIKey{ apiKey := &model.APIKey{
UserID: "userid", PlayerID: adminPlayer.ID,
Name: "API Key to Delete", Name: "API Key to Delete",
Key: "key-to-delete",
} }
err := repo.Put(apiKey) _, err := repo.Save(apiKey)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = repo.Delete(apiKey.ID) err = repo.Delete(apiKey.ID)
@ -141,6 +125,53 @@ var _ = Describe("APIKeyRepository", func() {
}) })
}) })
Describe("RefreshKey", func() {
It("generates a new key for an existing API key", func() {
apiKey := &model.APIKey{
PlayerID: adminPlayer.ID,
Name: "Test Refresh",
}
_, err := repo.Save(apiKey)
Expect(err).ToNot(HaveOccurred())
originalKey := apiKey.Key
newKey, err := repo.RefreshKey(apiKey.ID)
Expect(err).ToNot(HaveOccurred())
Expect(newKey).ToNot(BeEmpty())
Expect(newKey).ToNot(Equal(originalKey))
Expect(newKey).To(HavePrefix("nav_"))
refreshed, err := repo.Get(apiKey.ID)
Expect(err).ToNot(HaveOccurred())
Expect(refreshed.Key).To(Equal(newKey))
})
It("returns an error for non-existent API key", func() {
_, err := repo.RefreshKey("non-existent-id")
Expect(err).To(HaveOccurred())
Expect(err).To(MatchError(rest.ErrNotFound))
})
It("enforces user permissions", func() {
apiKey := &model.APIKey{
PlayerID: adminPlayer.ID,
Name: "Test Permission",
}
_, err := repo.Save(apiKey)
Expect(err).ToNot(HaveOccurred())
nonAdminCtx := log.NewContext(context.TODO())
nonAdminCtx = request.WithUser(nonAdminCtx, regularUser)
nonAdminRepo := NewAPIKeyRepository(nonAdminCtx, database)
_, err = nonAdminRepo.RefreshKey(apiKey.ID)
Expect(err).To(HaveOccurred())
Expect(err).To(MatchError(rest.ErrPermissionDenied))
})
})
Describe("User permissions", func() { Describe("User permissions", func() {
var nonAdminCtx context.Context var nonAdminCtx context.Context
var nonAdminRepo model.APIKeyRepository var nonAdminRepo model.APIKeyRepository
@ -148,8 +179,8 @@ var _ = Describe("APIKeyRepository", func() {
BeforeEach(func() { BeforeEach(func() {
nonAdminCtx = log.NewContext(context.TODO()) nonAdminCtx = log.NewContext(context.TODO())
nonAdminCtx = context.WithValue(nonAdminCtx, "user", model.User{ID: "2222", UserName: "user", IsAdmin: false}) nonAdminCtx = request.WithUser(nonAdminCtx, regularUser)
nonAdminRepo = NewAPIKeyRepository(nonAdminCtx, GetDBXBuilder()) nonAdminRepo = NewAPIKeyRepository(nonAdminCtx, database)
cleanupKeys := func(key string) { cleanupKeys := func(key string) {
foundKey, err := repo.FindByKey(key) foundKey, err := repo.FindByKey(key)
@ -161,20 +192,18 @@ var _ = Describe("APIKeyRepository", func() {
cleanupKeys("user-key") cleanupKeys("user-key")
tmpAdminKey := &model.APIKey{ tmpAdminKey := &model.APIKey{
UserID: "userid", PlayerID: adminPlayer.ID,
Name: "Admin's API Key", Name: "Admin's API Key",
Key: "admin-key",
} }
err := repo.Put(tmpAdminKey) _, err := repo.Save(tmpAdminKey)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
adminKey = *tmpAdminKey adminKey = *tmpAdminKey
userKey := &model.APIKey{ userKey := &model.APIKey{
UserID: "2222", PlayerID: regularPlayer.ID,
Name: "User's API Key", Name: "User's API Key",
Key: "user-key",
} }
err = repo.Put(userKey) _, err = repo.Save(userKey)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
@ -183,7 +212,7 @@ var _ = Describe("APIKeyRepository", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for _, key := range results { for _, key := range results {
Expect(key.UserID).To(Equal("2222")) Expect(key.PlayerID).To(Equal(regularPlayer.ID))
} }
}) })
@ -193,11 +222,11 @@ var _ = Describe("APIKeyRepository", func() {
userIds := make(map[string]bool) userIds := make(map[string]bool)
for _, key := range results { for _, key := range results {
userIds[key.UserID] = true userIds[key.PlayerID] = true
} }
Expect(userIds).To(HaveKey("userid")) Expect(userIds).To(HaveKey(adminPlayer.ID))
Expect(userIds).To(HaveKey("2222")) Expect(userIds).To(HaveKey(regularPlayer.ID))
}) })
It("a user cannot view/delete/update another user's key", func() { It("a user cannot view/delete/update another user's key", func() {

View File

@ -195,14 +195,17 @@ func (r *userRepository) FindByUsernameWithPassword(username string) (*model.Use
func (r *userRepository) FindByAPIKey(key string) (*model.User, error) { func (r *userRepository) FindByAPIKey(key string) (*model.User, error) {
// find the API key in the database // find the API key in the database
playerRepo := NewPlayerRepository(r.ctx, r.db)
apiKeyRepo := NewAPIKeyRepository(r.ctx, r.db) apiKeyRepo := NewAPIKeyRepository(r.ctx, r.db)
apiKey, err := apiKeyRepo.FindByKey(key) apiKey, err := apiKeyRepo.FindByKey(key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
player, err := playerRepo.Get(apiKey.PlayerID)
// Then get the user associated with this API key if err != nil {
return r.Get(apiKey.UserID) return nil, err
}
return r.Get(player.UserId)
} }
func (r *userRepository) UpdateLastLoginAt(id string) error { func (r *userRepository) UpdateLastLoginAt(id string) error {

View File

@ -212,7 +212,7 @@ var _ = Describe("UserRepository", func() {
var existingUser *model.User var existingUser *model.User
BeforeEach(func() { BeforeEach(func() {
existingUser = &model.User{ID: "1", UserName: "johndoe"} existingUser = &model.User{ID: "1", UserName: "johndoe"}
repo = tests.CreateMockUserRepo() repo = tests.CreateMockUserRepo(nil, nil)
err := repo.Put(existingUser) err := repo.Put(existingUser)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })

View File

@ -3,6 +3,8 @@ package nativeapi
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"github.com/navidrome/navidrome/utils/req"
"html" "html"
"net/http" "net/http"
"strconv" "strconv"
@ -74,6 +76,7 @@ func (n *Router) routes() http.Handler {
n.addUserLibraryRoute(r) n.addUserLibraryRoute(r)
n.RX(r, "/library", n.libs.NewRepository, true) n.RX(r, "/library", n.libs.NewRepository, true)
}) })
n.addRefreshApiKeyRoute(r)
}) })
return r return r
@ -247,3 +250,37 @@ func adminOnlyMiddleware(next http.Handler) http.Handler {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }
func (n *Router) addRefreshApiKeyRoute(r chi.Router) {
r.With(server.URLParamsMiddleware).Post("/apikey/{id}/refresh", func(w http.ResponseWriter, r *http.Request) {
p := req.Params(r)
id, err := p.String(":id")
if err != nil {
msg := fmt.Sprintf("api key id could not be parsed: %s", id)
log.Warn(msg)
http.Error(w, "not found", http.StatusNotFound)
return
}
repo := n.ds.APIKey(r.Context())
_, err = repo.RefreshKey(id)
if err != nil {
log.Error(r.Context(), "error refreshing api key", "id", id, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
updatedKey, err := repo.Get(id)
if err != nil {
log.Error(r.Context(), "error retrieving refreshed api key", "id", id, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
err = rest.RespondWithJSON(w, http.StatusOK, updatedKey)
if err != nil {
log.Error(r.Context(), "error marshaling refreshed api key", "id", id, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
})
}

View File

@ -308,6 +308,8 @@ var _ = Describe("Middlewares", func() {
}) })
When("using api key authentication", func() { When("using api key authentication", func() {
var apiKey *model.APIKey
BeforeEach(func() { BeforeEach(func() {
DeferCleanup(configtest.SetupConfig()) DeferCleanup(configtest.SetupConfig())
@ -318,18 +320,35 @@ var _ = Describe("Middlewares", func() {
} }
_ = ur.Put(user) _ = ur.Put(user)
ar := ds.APIKey(context.TODO()) pr := ds.Player(context.TODO())
apiKey := &model.APIKey{ player := &model.Player{
ID: "api-key-id", ID: "player1",
UserID: user.ID, Name: "Test Player",
Name: "API Key", UserAgent: "Test/1.0",
Key: "api-key", UserId: user.ID,
Client: "test-client",
IP: "127.0.0.1",
LastSeen: time.Now(),
TranscodingId: "",
MaxBitRate: 320,
ReportRealPath: false,
ScrobbleEnabled: true,
} }
_ = ar.Put(apiKey) _ = pr.Put(player)
ar := ds.APIKey(context.TODO())
newApiKey := &model.APIKey{
ID: "api-key-id",
Name: "API Key",
PlayerID: player.ID,
}
apiKeyId, _ := ar.Save(newApiKey)
newApiKey, _ = ar.Get(apiKeyId)
apiKey = newApiKey
}) })
It("passes authentication with correct api key", func() { It("passes authentication with correct api key", func() {
r := newGetRequest("apiKey=api-key") r := newGetRequest("apiKey=" + apiKey.Key)
cp := authenticate(ds)(next) cp := authenticate(ds)(next)
cp.ServeHTTP(w, r) cp.ServeHTTP(w, r)

View File

@ -2,6 +2,7 @@ package tests
import ( import (
"github.com/navidrome/navidrome/model" "github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/id"
"strings" "strings"
) )
@ -24,12 +25,14 @@ func (m *MockedAPIKeyRepo) CountAll(_ ...model.QueryOptions) (int64, error) {
return int64(len(m.Data)), nil return int64(len(m.Data)), nil
} }
func (m *MockedAPIKeyRepo) Put(apiKey *model.APIKey) error { func (m *MockedAPIKeyRepo) Save(entity interface{}) (string, error) {
if m.Error != nil { if m.Error != nil {
return m.Error return "", m.Error
} }
apiKey := entity.(*model.APIKey)
apiKey.Key = "nav_" + id.NewRandom()
m.Data[strings.ToLower(apiKey.Key)] = apiKey m.Data[strings.ToLower(apiKey.Key)] = apiKey
return nil return apiKey.ID, nil
} }
func (m *MockedAPIKeyRepo) FindByKey(key string) (*model.APIKey, error) { func (m *MockedAPIKeyRepo) FindByKey(key string) (*model.APIKey, error) {

View File

@ -170,8 +170,9 @@ func (db *MockDataStore) User(ctx context.Context) model.UserRepository {
if db.RealDS != nil { if db.RealDS != nil {
db.MockedUser = db.RealDS.User(ctx) db.MockedUser = db.RealDS.User(ctx)
} else { } else {
playerRepo := db.Player(ctx).(*MockedPlayerRepo)
apiKeyRepo := db.APIKey(ctx).(*MockedAPIKeyRepo) apiKeyRepo := db.APIKey(ctx).(*MockedAPIKeyRepo)
db.MockedUser = CreateMockUserRepo(apiKeyRepo) db.MockedUser = CreateMockUserRepo(playerRepo, apiKeyRepo)
} }
} }
return db.MockedUser return db.MockedUser
@ -193,7 +194,7 @@ func (db *MockDataStore) Player(ctx context.Context) model.PlayerRepository {
if db.RealDS != nil { if db.RealDS != nil {
db.MockedPlayer = db.RealDS.Player(ctx) db.MockedPlayer = db.RealDS.Player(ctx)
} else { } else {
db.MockedPlayer = struct{ model.PlayerRepository }{} db.MockedPlayer = CreateMockPlayerRepo()
} }
} }
return db.MockedPlayer return db.MockedPlayer

74
tests/mock_player_repo.go Normal file
View File

@ -0,0 +1,74 @@
package tests
import (
"encoding/base64"
"github.com/navidrome/navidrome/model"
)
func CreateMockPlayerRepo() *MockedPlayerRepo {
return &MockedPlayerRepo{
Data: make(map[string]*model.Player),
}
}
type MockedPlayerRepo struct {
model.PlayerRepository
Error error
Data map[string]*model.Player
}
func (m *MockedPlayerRepo) Get(id string) (*model.Player, error) {
if m.Error != nil {
return nil, m.Error
}
player, exists := m.Data[id]
if !exists {
return nil, model.ErrNotFound
}
return player, nil
}
func (m *MockedPlayerRepo) FindMatch(userId, client, userAgent string) (*model.Player, error) {
if m.Error != nil {
return nil, m.Error
}
for _, player := range m.Data {
if player.UserId == userId && player.Client == client && player.UserAgent == userAgent {
return player, nil
}
}
return nil, model.ErrNotFound
}
func (m *MockedPlayerRepo) Put(p *model.Player) error {
if m.Error != nil {
return m.Error
}
if p.ID == "" {
p.ID = base64.StdEncoding.EncodeToString([]byte(p.Name + "_" + p.UserId + "_" + p.Client))
}
m.Data[p.ID] = p
return nil
}
func (m *MockedPlayerRepo) CountAll(_ ...model.QueryOptions) (int64, error) {
if m.Error != nil {
return 0, m.Error
}
return int64(len(m.Data)), nil
}
func (m *MockedPlayerRepo) CountByClient(_ ...model.QueryOptions) (map[string]int64, error) {
if m.Error != nil {
return nil, m.Error
}
result := make(map[string]int64)
for _, player := range m.Data {
result[player.Client]++
}
return result, nil
}

View File

@ -30,6 +30,7 @@ type MockedUserRepo struct {
Data map[string]*model.User Data map[string]*model.User
UserLibraries map[string][]int // userID -> libraryIDs UserLibraries map[string][]int // userID -> libraryIDs
APIKeyRepo *MockedAPIKeyRepo APIKeyRepo *MockedAPIKeyRepo
PlayerRepo *MockedPlayerRepo
} }
func (u *MockedUserRepo) CountAll(_ ...model.QueryOptions) (int64, error) { func (u *MockedUserRepo) CountAll(_ ...model.QueryOptions) (int64, error) {
@ -142,8 +143,13 @@ func (u *MockedUserRepo) FindByAPIKey(key string) (*model.User, error) {
return nil, err return nil, err
} }
player, err := u.PlayerRepo.Get(apiKey.PlayerID)
if err != nil {
return nil, err
}
for _, usr := range u.Data { for _, usr := range u.Data {
if usr.ID == apiKey.UserID { if usr.ID == player.UserId {
return usr, nil return usr, nil
} }
} }