From 42b32c79201260a5b251841cf371b2c5b58b9a52 Mon Sep 17 00:00:00 2001 From: Gabe Farrell <90876006+gabehf@users.noreply.github.com> Date: Mon, 26 Jan 2026 13:48:43 -0500 Subject: [PATCH] feat: add api key auth to web api (#183) --- engine/long_test.go | 32 +++++ engine/middleware/authenticate.go | 167 ++++++++++++++++++++++++ engine/middleware/validate.go | 125 ------------------ engine/routes.go | 12 +- internal/cfg/cfg.go | 201 ----------------------------- internal/cfg/getters.go | 206 ++++++++++++++++++++++++++++++ internal/cfg/setters.go | 7 + 7 files changed, 418 insertions(+), 332 deletions(-) create mode 100644 engine/middleware/authenticate.go delete mode 100644 engine/middleware/validate.go create mode 100644 internal/cfg/getters.go create mode 100644 internal/cfg/setters.go diff --git a/engine/long_test.go b/engine/long_test.go index 2ef5d4b..d916117 100644 --- a/engine/long_test.go +++ b/engine/long_test.go @@ -356,6 +356,38 @@ func TestDelete(t *testing.T) { truncateTestData(t) } +func TestLoginGate(t *testing.T) { + + t.Run("Submit Listens", doSubmitListens) + + req, err := http.NewRequest("DELETE", host()+"/apis/web/v1/artist?id=1", nil) + require.NoError(t, err) + req.Header.Add("Authorization", "Token "+apikey) + resp, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, 204, resp.StatusCode) + + cfg.SetLoginGate(true) + + req, err = http.NewRequest("GET", host()+"/apis/web/v1/artist?id=3", nil) + require.NoError(t, err) + // req.Header.Add("Authorization", "Token "+apikey) + resp, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode) + + req, err = http.NewRequest("GET", host()+"/apis/web/v1/artist?id=3", nil) + require.NoError(t, err) + req.Header.Add("Authorization", "Token "+apikey) + resp, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + + cfg.SetLoginGate(false) + + truncateTestData(t) +} + func TestAliasesAndSearch(t *testing.T) { t.Run("Submit Listens", doSubmitListens) diff --git a/engine/middleware/authenticate.go b/engine/middleware/authenticate.go new file mode 100644 index 0000000..a435473 --- /dev/null +++ b/engine/middleware/authenticate.go @@ -0,0 +1,167 @@ +package middleware + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "github.com/gabehf/koito/internal/cfg" + "github.com/gabehf/koito/internal/db" + "github.com/gabehf/koito/internal/logger" + "github.com/gabehf/koito/internal/models" + "github.com/gabehf/koito/internal/utils" + "github.com/google/uuid" +) + +type MiddlwareContextKey string + +const ( + UserContextKey MiddlwareContextKey = "user" + apikeyContextKey MiddlwareContextKey = "apikeyID" +) + +type AuthMode int + +const ( + AuthModeSessionCookie AuthMode = iota + AuthModeAPIKey + AuthModeSessionOrAPIKey + AuthModeLoginGate +) + +func Authenticate(store db.DB, mode AuthMode) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + l := logger.FromContext(ctx) + + var user *models.User + var err error + + switch mode { + case AuthModeSessionCookie: + user, err = validateSession(ctx, store, r) + + case AuthModeAPIKey: + user, err = validateAPIKey(ctx, store, r) + + case AuthModeSessionOrAPIKey: + user, err = validateSession(ctx, store, r) + if err != nil || user == nil { + user, err = validateAPIKey(ctx, store, r) + } + + case AuthModeLoginGate: + if cfg.LoginGate() { + user, err = validateSession(ctx, store, r) + if err != nil || user == nil { + user, err = validateAPIKey(ctx, store, r) + } + } else { + next.ServeHTTP(w, r) + } + } + + if err != nil { + l.Err(err).Msg("authentication failed") + utils.WriteError(w, "unauthorized", http.StatusUnauthorized) + return + } + + if user == nil { + utils.WriteError(w, "unauthorized", http.StatusUnauthorized) + return + } + + if user != nil { + ctx = context.WithValue(ctx, UserContextKey, user) + r = r.WithContext(ctx) + } + + next.ServeHTTP(w, r) + }) + } +} + +func validateSession(ctx context.Context, store db.DB, r *http.Request) (*models.User, error) { + l := logger.FromContext(r.Context()) + + l.Debug().Msgf("ValidateSession: Checking user authentication via session cookie") + + cookie, err := r.Cookie("koito_session") + var sid uuid.UUID + if err == nil { + sid, err = uuid.Parse(cookie.Value) + if err != nil { + l.Err(err).Msg("ValidateSession: Could not parse UUID from session cookie") + return nil, errors.New("session cookie is invalid") + } + } else { + l.Debug().Msgf("ValidateSession: No session cookie found; attempting API key authentication") + return nil, errors.New("session cookie is missing") + } + + l.Debug().Msg("ValidateSession: Retrieved login cookie from request") + + u, err := store.GetUserBySession(r.Context(), sid) + if err != nil { + l.Err(fmt.Errorf("ValidateSession: %w", err)).Msg("Error accessing database") + return nil, errors.New("internal server error") + } + if u == nil { + l.Debug().Msg("ValidateSession: No user with session id found") + return nil, errors.New("no user with session id found") + } + + ctx = context.WithValue(r.Context(), UserContextKey, u) + r = r.WithContext(ctx) + + l.Debug().Msgf("ValidateSession: Refreshing session for user '%s'", u.Username) + + store.RefreshSession(r.Context(), sid, time.Now().Add(30*24*time.Hour)) + + l.Debug().Msgf("ValidateSession: Refreshed session for user '%s'", u.Username) + + return u, nil +} + +func validateAPIKey(ctx context.Context, store db.DB, r *http.Request) (*models.User, error) { + l := logger.FromContext(ctx) + + l.Debug().Msg("ValidateApiKey: Checking if user is already authenticated") + + authH := r.Header.Get("Authorization") + var token string + if strings.HasPrefix(strings.ToLower(authH), "token ") { + token = strings.TrimSpace(authH[6:]) // strip "Token " + } else { + l.Error().Msg("ValidateApiKey: Authorization header must be formatted 'Token {token}'") + return nil, errors.New("authorization header is invalid") + } + + u, err := store.GetUserByApiKey(ctx, token) + if err != nil { + l.Err(err).Msg("ValidateApiKey: Failed to get user from database using api key") + return nil, errors.New("internal server error") + } + if u == nil { + l.Debug().Msg("ValidateApiKey: API key does not exist") + return nil, errors.New("authorization token is invalid") + } + + ctx = context.WithValue(r.Context(), UserContextKey, u) + r = r.WithContext(ctx) + + return u, nil +} + +func GetUserFromContext(ctx context.Context) *models.User { + user, ok := ctx.Value(UserContextKey).(*models.User) + if !ok { + return nil + } + return user +} diff --git a/engine/middleware/validate.go b/engine/middleware/validate.go deleted file mode 100644 index b3e1369..0000000 --- a/engine/middleware/validate.go +++ /dev/null @@ -1,125 +0,0 @@ -package middleware - -import ( - "context" - "fmt" - "net/http" - "strings" - "time" - - "github.com/gabehf/koito/internal/db" - "github.com/gabehf/koito/internal/logger" - "github.com/gabehf/koito/internal/models" - "github.com/gabehf/koito/internal/utils" - "github.com/google/uuid" -) - -type MiddlwareContextKey string - -const ( - UserContextKey MiddlwareContextKey = "user" - apikeyContextKey MiddlwareContextKey = "apikeyID" -) - -func ValidateSession(store db.DB) func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - l := logger.FromContext(r.Context()) - - l.Debug().Msgf("ValidateSession: Checking user authentication via session cookie") - - cookie, err := r.Cookie("koito_session") - var sid uuid.UUID - if err == nil { - sid, err = uuid.Parse(cookie.Value) - if err != nil { - l.Err(err).Msg("ValidateSession: Could not parse UUID from session cookie") - utils.WriteError(w, "session cookie is invalid", http.StatusUnauthorized) - return - } - } else { - l.Debug().Msgf("ValidateSession: No session cookie found; attempting API key authentication") - utils.WriteError(w, "session cookie is missing", http.StatusUnauthorized) - return - } - - l.Debug().Msg("ValidateSession: Retrieved login cookie from request") - - u, err := store.GetUserBySession(r.Context(), sid) - if err != nil { - l.Err(fmt.Errorf("ValidateSession: %w", err)).Msg("Error accessing database") - utils.WriteError(w, "internal server error", http.StatusInternalServerError) - return - } - if u == nil { - l.Debug().Msg("ValidateSession: No user with session id found") - utils.WriteError(w, "unauthorized", http.StatusUnauthorized) - return - } - - ctx := context.WithValue(r.Context(), UserContextKey, u) - r = r.WithContext(ctx) - - l.Debug().Msgf("ValidateSession: Refreshing session for user '%s'", u.Username) - - store.RefreshSession(r.Context(), sid, time.Now().Add(30*24*time.Hour)) - - l.Debug().Msgf("ValidateSession: Refreshed session for user '%s'", u.Username) - - next.ServeHTTP(w, r) - }) - } -} - -func ValidateApiKey(store db.DB) func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - l := logger.FromContext(ctx) - - l.Debug().Msg("ValidateApiKey: Checking if user is already authenticated") - - u := GetUserFromContext(ctx) - if u != nil { - l.Debug().Msg("ValidateApiKey: User is already authenticated; skipping API key authentication") - next.ServeHTTP(w, r) - return - } - - authh := r.Header.Get("Authorization") - var token string - if strings.HasPrefix(strings.ToLower(authh), "token ") { - token = strings.TrimSpace(authh[6:]) // strip "Token " - } else { - l.Error().Msg("ValidateApiKey: Authorization header must be formatted 'Token {token}'") - utils.WriteError(w, "unauthorized", http.StatusUnauthorized) - return - } - - u, err := store.GetUserByApiKey(ctx, token) - if err != nil { - l.Err(err).Msg("Failed to get user from database using api key") - utils.WriteError(w, "internal server error", http.StatusInternalServerError) - return - } - if u == nil { - l.Debug().Msg("Api key does not exist") - utils.WriteError(w, "unauthorized", http.StatusUnauthorized) - return - } - - ctx = context.WithValue(r.Context(), UserContextKey, u) - r = r.WithContext(ctx) - - next.ServeHTTP(w, r) - }) - } -} - -func GetUserFromContext(ctx context.Context) *models.User { - user, ok := ctx.Value(UserContextKey).(*models.User) - if !ok { - return nil - } - return user -} diff --git a/engine/routes.go b/engine/routes.go index e1c5fda..c62edf5 100644 --- a/engine/routes.go +++ b/engine/routes.go @@ -38,9 +38,7 @@ func bindRoutes( r.Get("/config", handlers.GetCfgHandler()) r.Group(func(r chi.Router) { - if cfg.LoginGate() { - r.Use(middleware.ValidateSession(db)) - } + r.Use(middleware.Authenticate(db, middleware.AuthModeLoginGate)) r.Get("/artist", handlers.GetArtistHandler(db)) r.Get("/artists", handlers.GetArtistsForItemHandler(db)) r.Get("/album", handlers.GetAlbumHandler(db)) @@ -79,7 +77,7 @@ func bindRoutes( }) r.Group(func(r chi.Router) { - r.Use(middleware.ValidateSession(db)) + r.Use(middleware.Authenticate(db, middleware.AuthModeSessionOrAPIKey)) r.Get("/export", handlers.ExportHandler(db)) r.Post("/replace-image", handlers.ReplaceImageHandler(db)) r.Patch("/album", handlers.UpdateAlbumHandler(db)) @@ -111,8 +109,10 @@ func bindRoutes( AllowedHeaders: []string{"Content-Type", "Authorization"}, })) - r.With(middleware.ValidateApiKey(db)).Post("/submit-listens", handlers.LbzSubmitListenHandler(db, mbz)) - r.With(middleware.ValidateApiKey(db)).Get("/validate-token", handlers.LbzValidateTokenHandler(db)) + r.With(middleware.Authenticate(db, middleware.AuthModeAPIKey)). + Post("/submit-listens", handlers.LbzSubmitListenHandler(db, mbz)) + r.With(middleware.Authenticate(db, middleware.AuthModeAPIKey)). + Get("/validate-token", handlers.LbzValidateTokenHandler(db)) }) // serve react client diff --git a/internal/cfg/cfg.go b/internal/cfg/cfg.go index e74d6b9..0cfc7bb 100644 --- a/internal/cfg/cfg.go +++ b/internal/cfg/cfg.go @@ -244,204 +244,3 @@ func parseBool(s string) bool { return false } } - -// Global accessors for configuration values - -func UserAgent() string { - lock.RLock() - defer lock.RUnlock() - return globalConfig.userAgent -} - -func ListenAddr() string { - lock.RLock() - defer lock.RUnlock() - return fmt.Sprintf("%s:%d", globalConfig.bindAddr, globalConfig.listenPort) -} - -func ConfigDir() string { - lock.RLock() - defer lock.RUnlock() - return globalConfig.configDir -} - -func DatabaseUrl() string { - lock.RLock() - defer lock.RUnlock() - return globalConfig.databaseUrl -} - -func MusicBrainzUrl() string { - lock.RLock() - defer lock.RUnlock() - return globalConfig.musicBrainzUrl -} - -func MusicBrainzRateLimit() int { - lock.RLock() - defer lock.RUnlock() - return globalConfig.musicBrainzRateLimit -} - -func LogLevel() int { - lock.RLock() - defer lock.RUnlock() - return globalConfig.logLevel -} - -func StructuredLogging() bool { - lock.RLock() - defer lock.RUnlock() - return globalConfig.structuredLogging -} - -func LbzRelayEnabled() bool { - lock.RLock() - defer lock.RUnlock() - return globalConfig.lbzRelayEnabled -} - -func LbzRelayUrl() string { - lock.RLock() - defer lock.RUnlock() - return globalConfig.lbzRelayUrl -} - -func LbzRelayToken() string { - lock.RLock() - defer lock.RUnlock() - return globalConfig.lbzRelayToken -} - -func DefaultPassword() string { - lock.RLock() - defer lock.RUnlock() - return globalConfig.defaultPw -} - -func DefaultUsername() string { - lock.RLock() - defer lock.RUnlock() - return globalConfig.defaultUsername -} - -func DefaultTheme() string { - lock.RLock() - defer lock.RUnlock() - return globalConfig.defaultTheme -} - -func FullImageCacheEnabled() bool { - lock.RLock() - defer lock.RUnlock() - return globalConfig.enableFullImageCache -} - -func DeezerDisabled() bool { - lock.RLock() - defer lock.RUnlock() - return globalConfig.disableDeezer -} - -func CoverArtArchiveDisabled() bool { - lock.RLock() - defer lock.RUnlock() - return globalConfig.disableCAA -} - -func MusicBrainzDisabled() bool { - lock.RLock() - defer lock.RUnlock() - return globalConfig.disableMusicBrainz -} - -func SubsonicEnabled() bool { - lock.RLock() - defer lock.RUnlock() - return globalConfig.subsonicEnabled -} - -func SubsonicUrl() string { - lock.RLock() - defer lock.RUnlock() - return globalConfig.subsonicUrl -} - -func SubsonicParams() string { - lock.RLock() - defer lock.RUnlock() - return globalConfig.subsonicParams -} - -func LastFMApiKey() string { - lock.RLock() - defer lock.RUnlock() - return globalConfig.lastfmApiKey -} - -func SkipImport() bool { - lock.RLock() - defer lock.RUnlock() - return globalConfig.skipImport -} - -func AllowedHosts() []string { - lock.RLock() - defer lock.RUnlock() - return globalConfig.allowedHosts -} - -func AllowAllHosts() bool { - lock.RLock() - defer lock.RUnlock() - return globalConfig.allowAllHosts -} - -func AllowedOrigins() []string { - lock.RLock() - defer lock.RUnlock() - return globalConfig.allowedOrigins -} - -func RateLimitDisabled() bool { - lock.RLock() - defer lock.RUnlock() - return globalConfig.disableRateLimit -} - -func ThrottleImportMs() int { - lock.RLock() - defer lock.RUnlock() - return globalConfig.importThrottleMs -} - -// returns the before, after times, in that order -func ImportWindow() (time.Time, time.Time) { - lock.RLock() - defer lock.RUnlock() - return globalConfig.importBefore, globalConfig.importAfter -} - -func FetchImagesDuringImport() bool { - lock.RLock() - defer lock.RUnlock() - return globalConfig.fetchImageDuringImport -} - -func ArtistSeparators() []*regexp.Regexp { - lock.RLock() - defer lock.RUnlock() - return globalConfig.artistSeparators -} - -func LoginGate() bool { - lock.RLock() - defer lock.RUnlock() - return globalConfig.loginGate -} - -func ForceTZ() *time.Location { - lock.RLock() - defer lock.RUnlock() - return globalConfig.forceTZ -} diff --git a/internal/cfg/getters.go b/internal/cfg/getters.go new file mode 100644 index 0000000..596ca9d --- /dev/null +++ b/internal/cfg/getters.go @@ -0,0 +1,206 @@ +package cfg + +import ( + "fmt" + "regexp" + "time" +) + +func UserAgent() string { + lock.RLock() + defer lock.RUnlock() + return globalConfig.userAgent +} + +func ListenAddr() string { + lock.RLock() + defer lock.RUnlock() + return fmt.Sprintf("%s:%d", globalConfig.bindAddr, globalConfig.listenPort) +} + +func ConfigDir() string { + lock.RLock() + defer lock.RUnlock() + return globalConfig.configDir +} + +func DatabaseUrl() string { + lock.RLock() + defer lock.RUnlock() + return globalConfig.databaseUrl +} + +func MusicBrainzUrl() string { + lock.RLock() + defer lock.RUnlock() + return globalConfig.musicBrainzUrl +} + +func MusicBrainzRateLimit() int { + lock.RLock() + defer lock.RUnlock() + return globalConfig.musicBrainzRateLimit +} + +func LogLevel() int { + lock.RLock() + defer lock.RUnlock() + return globalConfig.logLevel +} + +func StructuredLogging() bool { + lock.RLock() + defer lock.RUnlock() + return globalConfig.structuredLogging +} + +func LbzRelayEnabled() bool { + lock.RLock() + defer lock.RUnlock() + return globalConfig.lbzRelayEnabled +} + +func LbzRelayUrl() string { + lock.RLock() + defer lock.RUnlock() + return globalConfig.lbzRelayUrl +} + +func LbzRelayToken() string { + lock.RLock() + defer lock.RUnlock() + return globalConfig.lbzRelayToken +} + +func DefaultPassword() string { + lock.RLock() + defer lock.RUnlock() + return globalConfig.defaultPw +} + +func DefaultUsername() string { + lock.RLock() + defer lock.RUnlock() + return globalConfig.defaultUsername +} + +func DefaultTheme() string { + lock.RLock() + defer lock.RUnlock() + return globalConfig.defaultTheme +} + +func FullImageCacheEnabled() bool { + lock.RLock() + defer lock.RUnlock() + return globalConfig.enableFullImageCache +} + +func DeezerDisabled() bool { + lock.RLock() + defer lock.RUnlock() + return globalConfig.disableDeezer +} + +func CoverArtArchiveDisabled() bool { + lock.RLock() + defer lock.RUnlock() + return globalConfig.disableCAA +} + +func MusicBrainzDisabled() bool { + lock.RLock() + defer lock.RUnlock() + return globalConfig.disableMusicBrainz +} + +func SubsonicEnabled() bool { + lock.RLock() + defer lock.RUnlock() + return globalConfig.subsonicEnabled +} + +func SubsonicUrl() string { + lock.RLock() + defer lock.RUnlock() + return globalConfig.subsonicUrl +} + +func SubsonicParams() string { + lock.RLock() + defer lock.RUnlock() + return globalConfig.subsonicParams +} + +func LastFMApiKey() string { + lock.RLock() + defer lock.RUnlock() + return globalConfig.lastfmApiKey +} + +func SkipImport() bool { + lock.RLock() + defer lock.RUnlock() + return globalConfig.skipImport +} + +func AllowedHosts() []string { + lock.RLock() + defer lock.RUnlock() + return globalConfig.allowedHosts +} + +func AllowAllHosts() bool { + lock.RLock() + defer lock.RUnlock() + return globalConfig.allowAllHosts +} + +func AllowedOrigins() []string { + lock.RLock() + defer lock.RUnlock() + return globalConfig.allowedOrigins +} + +func RateLimitDisabled() bool { + lock.RLock() + defer lock.RUnlock() + return globalConfig.disableRateLimit +} + +func ThrottleImportMs() int { + lock.RLock() + defer lock.RUnlock() + return globalConfig.importThrottleMs +} + +// returns the before, after times, in that order +func ImportWindow() (time.Time, time.Time) { + lock.RLock() + defer lock.RUnlock() + return globalConfig.importBefore, globalConfig.importAfter +} + +func FetchImagesDuringImport() bool { + lock.RLock() + defer lock.RUnlock() + return globalConfig.fetchImageDuringImport +} + +func ArtistSeparators() []*regexp.Regexp { + lock.RLock() + defer lock.RUnlock() + return globalConfig.artistSeparators +} + +func LoginGate() bool { + lock.RLock() + defer lock.RUnlock() + return globalConfig.loginGate +} + +func ForceTZ() *time.Location { + lock.RLock() + defer lock.RUnlock() + return globalConfig.forceTZ +} diff --git a/internal/cfg/setters.go b/internal/cfg/setters.go new file mode 100644 index 0000000..8458780 --- /dev/null +++ b/internal/cfg/setters.go @@ -0,0 +1,7 @@ +package cfg + +func SetLoginGate(val bool) { + lock.Lock() + defer lock.Unlock() + globalConfig.loginGate = val +}