chore: initial public commit

This commit is contained in:
Gabe Farrell 2025-06-11 19:45:39 -04:00
commit fc9054b78c
250 changed files with 32809 additions and 0 deletions

225
engine/engine.go Normal file
View file

@ -0,0 +1,225 @@
package engine
import (
"context"
"fmt"
"io"
"net/http"
"os"
"os/signal"
"path"
"strings"
"sync/atomic"
"syscall"
"time"
"github.com/gabehf/koito/engine/middleware"
"github.com/gabehf/koito/internal/catalog"
"github.com/gabehf/koito/internal/cfg"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/db/psql"
"github.com/gabehf/koito/internal/images"
"github.com/gabehf/koito/internal/importer"
"github.com/gabehf/koito/internal/logger"
mbz "github.com/gabehf/koito/internal/mbz"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/utils"
"github.com/go-chi/chi/v5"
chimiddleware "github.com/go-chi/chi/v5/middleware"
"github.com/rs/zerolog"
)
const Version = "dev"
func Run(
getenv func(string) string,
w io.Writer,
) error {
err := cfg.Load(getenv)
if err != nil {
return fmt.Errorf("failed to load configuration: %v", err)
}
l := logger.Get()
if cfg.StructuredLogging() {
*l = l.Output(w)
} else {
*l = l.Output(zerolog.ConsoleWriter{
Out: w,
TimeFormat: time.RFC3339,
// FormatLevel: func(i interface{}) string {
// return strings.ToUpper(fmt.Sprintf("[%s]", i))
// },
FormatMessage: func(i interface{}) string {
return fmt.Sprintf("\u001b[30;1m>\u001b[0m %s |", i)
},
})
}
ctx := logger.NewContext(l)
l.Info().Msgf("Koito %s", Version)
_, err = os.Stat(cfg.ConfigDir())
if err != nil {
l.Info().Msgf("Creating config dir: %s", cfg.ConfigDir())
err = os.MkdirAll(cfg.ConfigDir(), 0744)
if err != nil {
l.Error().Err(err).Msg("Failed to create config directory")
return err
}
}
l.Info().Msgf("Using config dir: %s", cfg.ConfigDir())
_, err = os.Stat(path.Join(cfg.ConfigDir(), "import"))
if err != nil {
l.Debug().Msgf("Creating import dir: %s", path.Join(cfg.ConfigDir(), "import"))
err = os.Mkdir(path.Join(cfg.ConfigDir(), "import"), 0744)
if err != nil {
l.Error().Err(err).Msg("Failed to create import directory")
return err
}
}
var store *psql.Psql
store, err = psql.New()
for err != nil {
l.Error().Err(err).Msg("Failed to connect to database; retrying in 5 seconds")
time.Sleep(5 * time.Second)
store, err = psql.New()
}
defer store.Close(ctx)
var mbzC mbz.MusicBrainzCaller
if !cfg.MusicBrainzDisabled() {
mbzC = mbz.NewMusicBrainzClient()
} else {
mbzC = &mbz.MbzErrorCaller{}
}
images.Initialize(images.ImageSourceOpts{
UserAgent: "Koito v0.0.1 (contact@koito.app)",
EnableCAA: !cfg.CoverArtArchiveDisabled(),
EnableDeezer: !cfg.DeezerDisabled(),
})
userCount, _ := store.CountUsers(ctx)
if userCount < 1 {
l.Debug().Msg("Creating default user...")
user, err := store.SaveUser(ctx, db.SaveUserOpts{
Username: cfg.DefaultUsername(),
Password: cfg.DefaultPassword(),
Role: models.UserRoleAdmin,
})
if err != nil {
l.Fatal().AnErr("error", err).Msg("Failed to save default user in database")
}
apikey, err := utils.GenerateRandomString(48)
if err != nil {
l.Fatal().AnErr("error", err).Msg("Failed to generate default api key")
}
label := "Default"
_, err = store.SaveApiKey(ctx, db.SaveApiKeyOpts{
Key: apikey,
UserID: user.ID,
Label: label,
})
if err != nil {
l.Fatal().AnErr("error", err).Msg("Failed to save default api key in database")
}
l.Info().Msgf("Default user has been created. Login: %s : %s", cfg.DefaultUsername(), cfg.DefaultPassword())
}
if cfg.AllowAllHosts() {
l.Warn().Msg("Your configuration allows requests from all hosts. This is a potential security risk!")
} else if len(cfg.AllowedHosts()) == 0 || cfg.AllowedHosts()[0] == "" {
l.Warn().Msgf("You are currently not allowing any hosts! Did you forget to set the %s variable?", cfg.ALLOWED_HOSTS_ENV)
} else {
l.Debug().Msgf("Allowing hosts: %v", cfg.AllowedHosts())
}
var ready atomic.Bool
mux := chi.NewRouter()
// bind general middleware to mux
mux.Use(middleware.WithRequestID)
mux.Use(middleware.Logger(l))
mux.Use(chimiddleware.Recoverer)
mux.Use(chimiddleware.RealIP)
// call router binds on mux
bindRoutes(mux, &ready, store, mbzC)
httpServer := &http.Server{
Addr: cfg.ListenAddr(),
Handler: mux,
}
go func() {
ready.Store(true) // signal readiness
l.Info().Msg("listening on " + cfg.ListenAddr())
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
l.Fatal().AnErr("error", err).Msg("Error when running ListenAndServe")
}
}()
// Import
if !cfg.SkipImport() {
go func() {
RunImporter(l, store)
}()
}
l.Info().Msg("Pruning orphaned images...")
go catalog.PruneOrphanedImages(logger.NewContext(l), store)
// Wait for interrupt signal to gracefully shutdown the server with a timeout of 10 seconds.
// Use a buffered channel to avoid missing signals as recommended for signal.Notify
quit := make(chan os.Signal, 1)
signal.Notify(quit, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
<-quit
l.Info().Msg("Received server shutdown notice")
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
l.Info().Msg("waiting for all processes to finish...")
mbzC.Shutdown()
if err := httpServer.Shutdown(ctx); err != nil {
return err
}
l.Info().Msg("shutdown successful")
return nil
}
func RunImporter(l *zerolog.Logger, store db.DB) {
l.Debug().Msg("Checking for import files...")
files, err := os.ReadDir(path.Join(cfg.ConfigDir(), "import"))
if err != nil {
l.Err(err).Msg("Failed to read files from import dir")
}
if len(files) > 0 {
l.Info().Msg("Files found in import directory. Attempting to import...")
} else {
return
}
defer func() {
if r := recover(); r != nil {
l.Error().Interface("recover", r).Msg("Panic when importing files")
}
}()
for _, file := range files {
if file.IsDir() {
continue
}
if strings.Contains(file.Name(), "Streaming_History_Audio") {
l.Info().Msgf("Import file %s detecting as being Spotify export", file.Name())
err := importer.ImportSpotifyFile(logger.NewContext(l), store, file.Name())
if err != nil {
l.Err(err).Msgf("Failed to import file: %s", file.Name())
}
} else if strings.Contains(file.Name(), "maloja") {
l.Info().Msgf("Import file %s detecting as being Maloja export", file.Name())
err := importer.ImportMalojaFile(logger.NewContext(l), store, file.Name())
if err != nil {
l.Err(err).Msgf("Failed to import file: %s", file.Name())
}
} else {
l.Warn().Msgf("File %s not recognized as a valid import file; make sure it is valid and named correctly", file.Name())
}
}
}

144
engine/engine_test.go Normal file
View file

@ -0,0 +1,144 @@
package engine_test
import (
"context"
"fmt"
"log"
"net"
"net/http"
"os"
"strconv"
"testing"
"time"
"github.com/gabehf/koito/engine"
"github.com/gabehf/koito/internal/cfg"
"github.com/gabehf/koito/internal/db/psql"
"github.com/gabehf/koito/internal/utils"
"github.com/ory/dockertest/v3"
)
var store *psql.Psql
func getTestGetenv(resource *dockertest.Resource) func(string) string {
dir, err := utils.GenerateRandomString(8)
if err != nil {
panic(err)
}
listener, err := net.Listen("tcp", ":0")
if err != nil {
panic(fmt.Errorf("failed to get an open port: %w", err))
}
defer listener.Close()
port := strconv.Itoa(listener.Addr().(*net.TCPAddr).Port)
return func(env string) string {
switch env {
case cfg.ENABLE_STRUCTURED_LOGGING_ENV:
return "true"
case cfg.LOG_LEVEL_ENV:
return "debug"
case cfg.DATABASE_URL_ENV:
return fmt.Sprintf("postgres://postgres:secret@localhost:%s", resource.GetPort("5432/tcp"))
case cfg.DEFAULT_PASSWORD_ENV:
return "testuser123"
case cfg.DEFAULT_USERNAME_ENV:
return "test"
case cfg.CONFIG_DIR_ENV:
return dir
case cfg.LISTEN_PORT_ENV:
return port
case cfg.ALLOWED_HOSTS_ENV:
return "*"
case cfg.DISABLE_DEEZER_ENV, cfg.DISABLE_COVER_ART_ARCHIVE_ENV, cfg.DISABLE_MUSICBRAINZ_ENV, cfg.SKIP_IMPORT_ENV:
return "true"
default:
return ""
}
}
}
func TestMain(m *testing.M) {
// uses a sensible default on windows (tcp/http) and linux/osx (socket)
pool, err := dockertest.NewPool("")
if err != nil {
log.Fatalf("Could not construct pool: %s", err)
}
// uses pool to try to connect to Docker
err = pool.Client.Ping()
if err != nil {
log.Fatalf("Could not connect to Docker: %s", err)
}
// pulls an image, creates a container based on it and runs it
resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret"})
if err != nil {
log.Fatalf("Could not start resource: %s", err)
}
getenv := getTestGetenv(resource)
err = cfg.Load(getenv)
if err != nil {
log.Fatalf("Could not load cfg: %s", err)
}
// exponential backoff-retry, because the application in the container might not be ready to accept connections yet
if err := pool.Retry(func() error {
var err error
store, err = psql.New()
if err != nil {
log.Println("Failed to connect to test database, retrying...")
return err
}
return store.Ping(context.Background())
}); err != nil {
log.Fatalf("Could not connect to database: %s", err)
}
go engine.Run(getenv, os.Stdout)
// Wait until the web server is reachable
for i := 0; i < 20; i++ {
url := fmt.Sprintf("http://%s/apis/web/v1/health", cfg.ListenAddr())
client := &http.Client{
Timeout: 2 * time.Second, // Set your desired timeout
}
resp, err := client.Get(url)
if err != nil {
if i >= 19 {
log.Fatalf("Web server is not reachable: %s", err)
}
log.Printf("Failed to connect to web server at %s, retrying... (%d/20)", url, i+1)
time.Sleep(1 * time.Second)
continue
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
err = nil
break
}
log.Printf("Unexpected status code at %s, retrying... (%d/20)", url, i+1)
time.Sleep(1 * time.Second)
}
code := m.Run()
// You can't defer this because os.Exit doesn't care for defer
if err := pool.Purge(resource); err != nil {
log.Fatalf("Could not purge resource: %s", err)
}
err = os.RemoveAll(cfg.ConfigDir())
if err != nil {
log.Fatalf("Could not remove temporary config dir: %v", err)
}
os.Exit(code)
}
func host() string {
return fmt.Sprintf("http://%s", cfg.ListenAddr())
}

270
engine/handlers/alias.go Normal file
View file

@ -0,0 +1,270 @@
package handlers
import (
"net/http"
"strconv"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/utils"
)
// GetAliasesHandler retrieves all aliases for a given artist or album ID.
func GetAliasesHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := logger.FromContext(ctx)
// Parse query parameters
artistIDStr := r.URL.Query().Get("artist_id")
albumIDStr := r.URL.Query().Get("album_id")
trackIDStr := r.URL.Query().Get("track_id")
if artistIDStr == "" && albumIDStr == "" && trackIDStr == "" {
utils.WriteError(w, "artist_id, album_id, or track_id must be provided", http.StatusBadRequest)
return
}
var aliases []models.Alias
if artistIDStr != "" {
artistID, err := strconv.Atoi(artistIDStr)
if err != nil {
utils.WriteError(w, "invalid artist_id", http.StatusBadRequest)
return
}
aliases, err = store.GetAllArtistAliases(ctx, int32(artistID))
if err != nil {
l.Err(err).Msg("Failed to get artist aliases")
utils.WriteError(w, "failed to retrieve aliases", http.StatusInternalServerError)
return
}
} else if albumIDStr != "" {
albumID, err := strconv.Atoi(albumIDStr)
if err != nil {
utils.WriteError(w, "invalid album_id", http.StatusBadRequest)
return
}
aliases, err = store.GetAllAlbumAliases(ctx, int32(albumID))
if err != nil {
l.Err(err).Msg("Failed to get artist aliases")
utils.WriteError(w, "failed to retrieve aliases", http.StatusInternalServerError)
return
}
} else if trackIDStr != "" {
trackID, err := strconv.Atoi(trackIDStr)
if err != nil {
utils.WriteError(w, "invalid track_id", http.StatusBadRequest)
return
}
aliases, err = store.GetAllTrackAliases(ctx, int32(trackID))
if err != nil {
l.Err(err).Msg("Failed to get artist aliases")
utils.WriteError(w, "failed to retrieve aliases", http.StatusInternalServerError)
return
}
}
utils.WriteJSON(w, http.StatusOK, aliases)
}
}
// DeleteAliasHandler deletes an alias for a given artist or album ID.
func DeleteAliasHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := logger.FromContext(ctx)
// Parse query parameters
artistIDStr := r.URL.Query().Get("artist_id")
albumIDStr := r.URL.Query().Get("album_id")
trackIDStr := r.URL.Query().Get("track_id")
alias := r.URL.Query().Get("alias")
if alias == "" || (artistIDStr == "" && albumIDStr == "" && trackIDStr == "") {
utils.WriteError(w, "alias and artist_id, album_id, or track_id must be provided", http.StatusBadRequest)
return
}
if utils.MoreThanOneString(artistIDStr, albumIDStr, trackIDStr) {
utils.WriteError(w, "only one of artist_id, album_id, or track_id can be provided at a time", http.StatusBadRequest)
return
}
if artistIDStr != "" {
artistID, err := strconv.Atoi(artistIDStr)
if err != nil {
utils.WriteError(w, "invalid artist_id", http.StatusBadRequest)
return
}
err = store.DeleteArtistAlias(ctx, int32(artistID), alias)
if err != nil {
l.Err(err).Msg("Failed to delete alias")
utils.WriteError(w, "failed to delete alias", http.StatusInternalServerError)
return
}
} else if albumIDStr != "" {
albumID, err := strconv.Atoi(albumIDStr)
if err != nil {
utils.WriteError(w, "invalid album_id", http.StatusBadRequest)
return
}
err = store.DeleteAlbumAlias(ctx, int32(albumID), alias)
if err != nil {
l.Err(err).Msg("Failed to delete alias")
utils.WriteError(w, "failed to delete alias", http.StatusInternalServerError)
return
}
} else if trackIDStr != "" {
trackID, err := strconv.Atoi(trackIDStr)
if err != nil {
utils.WriteError(w, "invalid album_id", http.StatusBadRequest)
return
}
err = store.DeleteTrackAlias(ctx, int32(trackID), alias)
if err != nil {
l.Err(err).Msg("Failed to delete alias")
utils.WriteError(w, "failed to delete alias", http.StatusInternalServerError)
return
}
}
w.WriteHeader(http.StatusNoContent)
}
}
// CreateAliasHandler creates new aliases for a given artist, album, or track.
func CreateAliasHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := logger.FromContext(ctx)
err := r.ParseForm()
if err != nil {
utils.WriteError(w, "invalid request body", http.StatusBadRequest)
return
}
artistIDStr := r.URL.Query().Get("artist_id")
albumIDStr := r.URL.Query().Get("album_id")
trackIDStr := r.URL.Query().Get("track_id")
if artistIDStr == "" && albumIDStr == "" && trackIDStr == "" {
utils.WriteError(w, "artist_id, album_id, or track_id must be provided", http.StatusBadRequest)
return
}
if utils.MoreThanOneString(artistIDStr, albumIDStr, trackIDStr) {
utils.WriteError(w, "only one of artist_id, album_id, or track_id can be provided at a time", http.StatusBadRequest)
return
}
alias := r.FormValue("alias")
if alias == "" {
utils.WriteError(w, "alias must be provided", http.StatusBadRequest)
return
}
if artistIDStr != "" {
artistID, err := strconv.Atoi(artistIDStr)
if err != nil {
utils.WriteError(w, "invalid artist_id", http.StatusBadRequest)
return
}
err = store.SaveArtistAliases(ctx, int32(artistID), []string{alias}, "Manual")
if err != nil {
l.Err(err).Msg("Failed to save alias")
utils.WriteError(w, "failed to save alias", http.StatusInternalServerError)
return
}
} else if albumIDStr != "" {
albumID, err := strconv.Atoi(albumIDStr)
if err != nil {
utils.WriteError(w, "invalid album_id", http.StatusBadRequest)
return
}
err = store.SaveAlbumAliases(ctx, int32(albumID), []string{alias}, "Manual")
if err != nil {
l.Err(err).Msg("Failed to save alias")
utils.WriteError(w, "failed to save alias", http.StatusInternalServerError)
return
}
} else if trackIDStr != "" {
trackID, err := strconv.Atoi(trackIDStr)
if err != nil {
utils.WriteError(w, "invalid track_id", http.StatusBadRequest)
return
}
err = store.SaveTrackAliases(ctx, int32(trackID), []string{alias}, "Manual")
if err != nil {
l.Err(err).Msg("Failed to save alias")
utils.WriteError(w, "failed to save alias", http.StatusInternalServerError)
return
}
}
w.WriteHeader(http.StatusCreated)
}
}
// sets the primary alias for albums, artists, and tracks
func SetPrimaryAliasHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := logger.FromContext(ctx)
// Parse query parameters
artistIDStr := r.URL.Query().Get("artist_id")
albumIDStr := r.URL.Query().Get("album_id")
trackIDStr := r.URL.Query().Get("track_id")
alias := r.URL.Query().Get("alias")
if alias == "" || (artistIDStr == "" && albumIDStr == "" && trackIDStr == "") {
utils.WriteError(w, "alias and artist_id, album_id, or track_id must be provided", http.StatusBadRequest)
return
}
if utils.MoreThanOneString(artistIDStr, albumIDStr, trackIDStr) {
utils.WriteError(w, "only one of artist_id, album_id, or track_id can be provided", http.StatusBadRequest)
return
}
if artistIDStr != "" {
artistID, err := strconv.Atoi(artistIDStr)
if err != nil {
utils.WriteError(w, "invalid artist_id", http.StatusBadRequest)
return
}
err = store.SetPrimaryArtistAlias(ctx, int32(artistID), alias)
if err != nil {
l.Err(err).Msg("Failed to set primary alias")
utils.WriteError(w, "failed to set primary alias", http.StatusInternalServerError)
return
}
} else if albumIDStr != "" {
albumID, err := strconv.Atoi(albumIDStr)
if err != nil {
utils.WriteError(w, "invalid album_id", http.StatusBadRequest)
return
}
err = store.SetPrimaryAlbumAlias(ctx, int32(albumID), alias)
if err != nil {
l.Err(err).Msg("Failed to set primary alias")
utils.WriteError(w, "failed to set primary alias", http.StatusInternalServerError)
return
}
} else if trackIDStr != "" {
trackID, err := strconv.Atoi(trackIDStr)
if err != nil {
utils.WriteError(w, "invalid track_id", http.StatusBadRequest)
return
}
err = store.SetPrimaryTrackAlias(ctx, int32(trackID), alias)
if err != nil {
l.Err(err).Msg("Failed to set primary alias")
utils.WriteError(w, "failed to set primary alias", http.StatusInternalServerError)
return
}
}
w.WriteHeader(http.StatusNoContent)
}
}

153
engine/handlers/apikeys.go Normal file
View file

@ -0,0 +1,153 @@
package handlers
import (
"net/http"
"strconv"
"github.com/gabehf/koito/engine/middleware"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/utils"
)
func GenerateApiKeyHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := logger.FromContext(ctx)
user := middleware.GetUserFromContext(ctx)
if user == nil {
utils.WriteError(w, "unauthorized", http.StatusUnauthorized)
return
}
r.ParseForm()
label := r.FormValue("label")
if label == "" {
utils.WriteError(w, "label is required", http.StatusBadRequest)
return
}
apiKey, err := utils.GenerateRandomString(48)
if err != nil {
l.Err(err).Msg("Failed to generate API key")
utils.WriteError(w, "failed to generate api key", http.StatusInternalServerError)
return
}
opts := db.SaveApiKeyOpts{
UserID: user.ID,
Key: apiKey,
Label: label,
}
l.Debug().Any("opts", opts).Send()
key, err := store.SaveApiKey(ctx, opts)
if err != nil {
l.Err(err).Msg("Failed to save API key")
utils.WriteError(w, "failed to save api key", http.StatusInternalServerError)
return
}
utils.WriteJSON(w, 201, key)
}
}
func DeleteApiKeyHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := logger.FromContext(ctx)
user := middleware.GetUserFromContext(ctx)
if user == nil {
utils.WriteError(w, "unauthorized", http.StatusUnauthorized)
return
}
idStr := r.URL.Query().Get("id")
if idStr == "" {
utils.WriteError(w, "id is required", http.StatusBadRequest)
return
}
apiKey, err := strconv.Atoi(idStr)
if err != nil {
utils.WriteError(w, "id is invalid", http.StatusBadRequest)
return
}
err = store.DeleteApiKey(ctx, int32(apiKey))
if err != nil {
l.Err(err).Msg("Failed to delete API key")
utils.WriteError(w, "failed to delete api key", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusNoContent)
}
}
func GetApiKeysHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := logger.FromContext(ctx)
l.Debug().Msgf("Retrieving user from middleware...")
user := middleware.GetUserFromContext(ctx)
if user == nil {
l.Debug().Msgf("Could not retrieve user from middleware")
utils.WriteError(w, "unauthorized", http.StatusUnauthorized)
return
}
l.Debug().Msgf("Retrieved user '%s' from middleware", user.Username)
apiKeys, err := store.GetApiKeysByUserID(ctx, user.ID)
if err != nil {
l.Err(err).Msg("Failed to retrieve API keys")
utils.WriteError(w, "failed to retrieve api keys", http.StatusInternalServerError)
return
}
utils.WriteJSON(w, http.StatusOK, apiKeys)
}
}
func UpdateApiKeyLabelHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := logger.FromContext(ctx)
user := middleware.GetUserFromContext(ctx)
if user == nil {
utils.WriteError(w, "unauthorized", http.StatusUnauthorized)
return
}
idStr := r.URL.Query().Get("id")
if idStr == "" {
utils.WriteError(w, "id is required", http.StatusBadRequest)
return
}
apiKeyID, err := strconv.Atoi(idStr)
if err != nil {
utils.WriteError(w, "id is invalid", http.StatusBadRequest)
return
}
label := r.FormValue("label")
if label == "" {
utils.WriteError(w, "label is required", http.StatusBadRequest)
return
}
err = store.UpdateApiKeyLabel(ctx, db.UpdateApiKeyLabelOpts{
UserID: user.ID,
ID: int32(apiKeyID),
Label: label,
})
if err != nil {
l.Err(err).Msg("Failed to update API key label")
utils.WriteError(w, "failed to update api key label", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}
}

149
engine/handlers/auth.go Normal file
View file

@ -0,0 +1,149 @@
package handlers
import (
"net/http"
"strings"
"time"
"github.com/gabehf/koito/engine/middleware"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/utils"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
func LoginHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
l := logger.FromContext(r.Context())
ctx := r.Context()
l.Debug().Msg("Recieved login request")
r.ParseForm()
username := r.FormValue("username")
password := r.FormValue("password")
if username == "" || password == "" {
utils.WriteError(w, "username and password are required", http.StatusBadRequest)
return
}
user, err := store.GetUserByUsername(ctx, username)
if err != nil {
l.Err(err).Msg("Error searching for user in database")
utils.WriteError(w, "internal server error", http.StatusInternalServerError)
return
} else if user == nil {
utils.WriteError(w, "username or password is incorrect", http.StatusBadRequest)
return
}
err = bcrypt.CompareHashAndPassword(user.Password, []byte(password))
if err != nil {
utils.WriteError(w, "username or password is incorrect", http.StatusBadRequest)
return
}
keepSignedIn := false
expiresAt := time.Now().Add(1 * 24 * time.Hour)
if strings.ToLower(r.FormValue("remember_me")) == "true" {
keepSignedIn = true
expiresAt = time.Now().Add(30 * 24 * time.Hour)
}
session, err := store.SaveSession(ctx, user.ID, expiresAt, keepSignedIn)
if err != nil {
l.Err(err).Msg("Failed to create session")
utils.WriteError(w, "failed to create session", http.StatusInternalServerError)
return
}
cookie := &http.Cookie{
Name: "koito_session",
Value: session.ID.String(),
Path: "/",
HttpOnly: true,
Secure: false,
}
if keepSignedIn {
cookie.Expires = expiresAt
}
http.SetCookie(w, cookie)
w.WriteHeader(http.StatusNoContent)
}
}
func LogoutHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
l := logger.FromContext(r.Context())
cookie, err := r.Cookie("koito_session")
if err == nil {
sid, err := uuid.Parse(cookie.Value)
if err != nil {
utils.WriteError(w, "session cookie is invalid", http.StatusUnauthorized)
return
}
err = store.DeleteSession(r.Context(), sid)
if err != nil {
l.Err(err).Msg("Failed to delete session")
utils.WriteError(w, "internal server error", http.StatusInternalServerError)
return
}
}
// Clear the cookie
http.SetCookie(w, &http.Cookie{
Name: "koito_session",
Value: "",
Path: "/",
HttpOnly: true,
MaxAge: -1, // expire immediately
})
w.WriteHeader(http.StatusNoContent)
}
}
func MeHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := logger.FromContext(ctx)
u := middleware.GetUserFromContext(ctx)
if u == nil {
l.Debug().Msg("Invalid user retrieved from context")
}
utils.WriteJSON(w, 200, u)
}
}
func UpdateUserHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := logger.FromContext(ctx)
u := middleware.GetUserFromContext(ctx)
if u == nil {
utils.WriteError(w, "unauthorized", http.StatusUnauthorized)
return
}
r.ParseForm()
username := r.FormValue("username")
password := r.FormValue("password")
l.Debug().Msgf("Recieved update request for user with id %d", u.ID)
err := store.UpdateUser(ctx, db.UpdateUserOpts{
ID: u.ID,
Username: username,
Password: password,
})
if err != nil {
l.Err(err).Msg("Failed to update user")
utils.WriteError(w, err.Error(), http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusNoContent)
}
}

137
engine/handlers/delete.go Normal file
View file

@ -0,0 +1,137 @@
package handlers
import (
"net/http"
"strconv"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/utils"
)
// DeleteTrackHandler deletes a track by its ID.
func DeleteTrackHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := logger.FromContext(ctx)
trackIDStr := r.URL.Query().Get("id")
if trackIDStr == "" {
utils.WriteError(w, "track_id must be provided", http.StatusBadRequest)
return
}
trackID, err := strconv.Atoi(trackIDStr)
if err != nil {
utils.WriteError(w, "invalid id", http.StatusBadRequest)
return
}
err = store.DeleteTrack(ctx, int32(trackID))
if err != nil {
l.Err(err).Msg("Failed to delete track")
utils.WriteError(w, "failed to delete track", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusNoContent)
}
}
// DeleteTrackHandler deletes a track by its ID.
func DeleteListenHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := logger.FromContext(ctx)
trackIDStr := r.URL.Query().Get("track_id")
if trackIDStr == "" {
utils.WriteError(w, "track_id must be provided", http.StatusBadRequest)
return
}
trackID, err := strconv.Atoi(trackIDStr)
if err != nil {
utils.WriteError(w, "invalid id", http.StatusBadRequest)
return
}
unixStr := r.URL.Query().Get("unix")
if trackIDStr == "" {
utils.WriteError(w, "unix timestamp must be provided", http.StatusBadRequest)
return
}
unix, err := strconv.ParseInt(unixStr, 10, 64)
if err != nil {
utils.WriteError(w, "invalid unix timestamp", http.StatusBadRequest)
return
}
err = store.DeleteListen(ctx, int32(trackID), time.Unix(unix, 0))
if err != nil {
l.Err(err).Msg("Failed to delete listen")
utils.WriteError(w, "failed to delete listen", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusNoContent)
}
}
// DeleteArtistHandler deletes an artist by its ID.
func DeleteArtistHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := logger.FromContext(ctx)
artistIDStr := r.URL.Query().Get("id")
if artistIDStr == "" {
utils.WriteError(w, "id must be provided", http.StatusBadRequest)
return
}
artistID, err := strconv.Atoi(artistIDStr)
if err != nil {
utils.WriteError(w, "invalid id", http.StatusBadRequest)
return
}
err = store.DeleteArtist(ctx, int32(artistID))
if err != nil {
l.Err(err).Msg("Failed to delete artist")
utils.WriteError(w, "failed to delete artist", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusNoContent)
}
}
// DeleteAlbumHandler deletes an album by its ID.
func DeleteAlbumHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := logger.FromContext(ctx)
albumIDStr := r.URL.Query().Get("id")
if albumIDStr == "" {
utils.WriteError(w, "id must be provided", http.StatusBadRequest)
return
}
albumID, err := strconv.Atoi(albumIDStr)
if err != nil {
utils.WriteError(w, "invalid id", http.StatusBadRequest)
return
}
err = store.DeleteAlbum(ctx, int32(albumID))
if err != nil {
l.Err(err).Msg("Failed to delete album")
utils.WriteError(w, "failed to delete album", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusNoContent)
}
}

View file

@ -0,0 +1,28 @@
package handlers
import (
"net/http"
"strconv"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/utils"
)
func GetAlbumHandler(store db.DB) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
idStr := r.URL.Query().Get("id")
id, err := strconv.Atoi(idStr)
if err != nil {
utils.WriteError(w, "id is invalid", 400)
return
}
album, err := store.GetAlbum(r.Context(), db.GetAlbumOpts{ID: int32(id)})
if err != nil {
utils.WriteError(w, "album with specified id could not be found", http.StatusNotFound)
return
}
utils.WriteJSON(w, http.StatusOK, album)
}
}

View file

@ -0,0 +1,28 @@
package handlers
import (
"net/http"
"strconv"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/utils"
)
func GetArtistHandler(store db.DB) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
idStr := r.URL.Query().Get("id")
id, err := strconv.Atoi(idStr)
if err != nil {
utils.WriteError(w, "id is invalid", 400)
return
}
artist, err := store.GetArtist(r.Context(), db.GetArtistOpts{ID: int32(id)})
if err != nil {
utils.WriteError(w, "artist with specified id could not be found", http.StatusNotFound)
return
}
utils.WriteJSON(w, http.StatusOK, artist)
}
}

View file

@ -0,0 +1,65 @@
package handlers
import (
"net/http"
"strconv"
"strings"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/utils"
)
func GetListenActivityHandler(store db.DB) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
l := logger.FromContext(r.Context())
rangeStr := r.URL.Query().Get("range")
_range, _ := strconv.Atoi(rangeStr)
monthStr := r.URL.Query().Get("month")
month, _ := strconv.Atoi(monthStr)
yearStr := r.URL.Query().Get("year")
year, _ := strconv.Atoi(yearStr)
artistIdStr := r.URL.Query().Get("artist_id")
artistId, _ := strconv.Atoi(artistIdStr)
albumIdStr := r.URL.Query().Get("album_id")
albumId, _ := strconv.Atoi(albumIdStr)
trackIdStr := r.URL.Query().Get("track_id")
trackId, _ := strconv.Atoi(trackIdStr)
var step db.StepInterval
switch strings.ToLower(r.URL.Query().Get("step")) {
case "day":
step = db.StepDay
case "week":
step = db.StepWeek
case "month":
step = db.StepMonth
case "year":
step = db.StepYear
default:
l.Debug().Msgf("Using default value '%s' for step", db.StepDefault)
step = db.StepDay
}
opts := db.ListenActivityOpts{
Step: step,
Range: _range,
Month: month,
Year: year,
AlbumID: int32(albumId),
ArtistID: int32(artistId),
TrackID: int32(trackId),
}
activity, err := store.GetListenActivity(r.Context(), opts)
if err != nil {
l.Err(err).Send()
utils.WriteError(w, err.Error(), 500)
return
}
utils.WriteJSON(w, http.StatusOK, activity)
}
}

View file

@ -0,0 +1,23 @@
package handlers
import (
"net/http"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/utils"
)
func GetListensHandler(store db.DB) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
l := logger.FromContext(r.Context())
opts := OptsFromRequest(r)
listens, err := store.GetListensPaginated(r.Context(), opts)
if err != nil {
l.Err(err).Send()
utils.WriteError(w, "failed to get listens: "+err.Error(), 400)
return
}
utils.WriteJSON(w, http.StatusOK, listens)
}
}

View file

@ -0,0 +1,23 @@
package handlers
import (
"net/http"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/utils"
)
func GetTopAlbumsHandler(store db.DB) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
l := logger.FromContext(r.Context())
opts := OptsFromRequest(r)
albums, err := store.GetTopAlbumsPaginated(r.Context(), opts)
if err != nil {
l.Err(err).Msg("Failed to get top albums")
utils.WriteError(w, "failed to get albums", 400)
return
}
utils.WriteJSON(w, http.StatusOK, albums)
}
}

View file

@ -0,0 +1,23 @@
package handlers
import (
"net/http"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/utils"
)
func GetTopArtistsHandler(store db.DB) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
l := logger.FromContext(r.Context())
opts := OptsFromRequest(r)
artists, err := store.GetTopArtistsPaginated(r.Context(), opts)
if err != nil {
l.Err(err).Msg("Failed to get top artists")
utils.WriteError(w, "failed to get artists", 400)
return
}
utils.WriteJSON(w, http.StatusOK, artists)
}
}

View file

@ -0,0 +1,23 @@
package handlers
import (
"net/http"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/utils"
)
func GetTopTracksHandler(store db.DB) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
l := logger.FromContext(r.Context())
opts := OptsFromRequest(r)
tracks, err := store.GetTopTracksPaginated(r.Context(), opts)
if err != nil {
l.Err(err).Msg("Failed to get top tracks")
utils.WriteError(w, "failed to get tracks", 400)
return
}
utils.WriteJSON(w, http.StatusOK, tracks)
}
}

View file

@ -0,0 +1,31 @@
package handlers
import (
"net/http"
"strconv"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/utils"
)
func GetTrackHandler(store db.DB) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
l := logger.FromContext(r.Context())
idStr := r.URL.Query().Get("id")
id, err := strconv.Atoi(idStr)
if err != nil {
utils.WriteError(w, "id is invalid", 400)
return
}
track, err := store.GetTrack(r.Context(), db.GetTrackOpts{ID: int32(id)})
if err != nil {
l.Err(err).Msg("Failed to get top albums")
utils.WriteError(w, "track with specified id could not be found", http.StatusNotFound)
return
}
utils.WriteJSON(w, http.StatusOK, track)
}
}

View file

@ -0,0 +1,77 @@
// package handlers implements route handlers
package handlers
import (
"net/http"
"strconv"
"strings"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
)
const defaultLimitSize = 100
const maximumLimit = 500
func OptsFromRequest(r *http.Request) db.GetItemsOpts {
l := logger.FromContext(r.Context())
limitStr := r.URL.Query().Get("limit")
limit, err := strconv.Atoi(limitStr)
if err != nil {
l.Debug().Msgf("query parameter 'limit' not specified, using default %d", defaultLimitSize)
limit = defaultLimitSize
}
if limit > maximumLimit {
l.Debug().Msgf("limit must not be greater than %d, using default %d", maximumLimit, defaultLimitSize)
limit = defaultLimitSize
}
pageStr := r.URL.Query().Get("page")
page, _ := strconv.Atoi(pageStr)
if page < 1 {
page = 1
}
weekStr := r.URL.Query().Get("week")
week, _ := strconv.Atoi(weekStr)
monthStr := r.URL.Query().Get("month")
month, _ := strconv.Atoi(monthStr)
yearStr := r.URL.Query().Get("year")
year, _ := strconv.Atoi(yearStr)
artistIdStr := r.URL.Query().Get("artist_id")
artistId, _ := strconv.Atoi(artistIdStr)
albumIdStr := r.URL.Query().Get("album_id")
albumId, _ := strconv.Atoi(albumIdStr)
trackIdStr := r.URL.Query().Get("track_id")
trackId, _ := strconv.Atoi(trackIdStr)
var period db.Period
switch strings.ToLower(r.URL.Query().Get("period")) {
case "day":
period = db.PeriodDay
case "week":
period = db.PeriodWeek
case "month":
period = db.PeriodMonth
case "year":
period = db.PeriodYear
case "all_time":
period = db.PeriodAllTime
default:
l.Debug().Msgf("Using default value '%s' for period", db.PeriodDay)
period = db.PeriodDay
}
return db.GetItemsOpts{
Limit: limit,
Period: period,
Page: page,
Week: week,
Month: month,
Year: year,
ArtistID: artistId,
AlbumID: albumId,
TrackID: trackId,
}
}

10
engine/handlers/health.go Normal file
View file

@ -0,0 +1,10 @@
package handlers
import "net/http"
func HealthHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"status":"ready"}`))
}
}

View file

@ -0,0 +1,208 @@
package handlers
import (
"bytes"
"net/http"
"os"
"path"
"path/filepath"
"sync"
"github.com/gabehf/koito/internal/catalog"
"github.com/gabehf/koito/internal/cfg"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/utils"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
)
func ImageHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
l := logger.FromContext(r.Context())
size := chi.URLParam(r, "size")
filename := chi.URLParam(r, "filename")
imageSize, err := catalog.ParseImageSize(size)
if err != nil {
w.WriteHeader(http.StatusNotFound)
return
}
imgid, err := uuid.Parse(filename)
if err != nil {
serveDefaultImage(w, r, imageSize)
return
}
desiredImgPath := filepath.Join(cfg.ConfigDir(), catalog.ImageCacheDir, size, filepath.Clean(filename))
if _, err := os.Stat(desiredImgPath); os.IsNotExist(err) {
l.Debug().Msg("Image not found in desired size")
// file doesn't exist in desired size
fullSizePath := filepath.Join(cfg.ConfigDir(), catalog.ImageCacheDir, string(catalog.ImageSizeFull), filepath.Clean(filename))
largeSizePath := filepath.Join(cfg.ConfigDir(), catalog.ImageCacheDir, string(catalog.ImageSizeLarge), filepath.Clean(filename))
// check if file exists at either full or large size
// note: have to check both in case a user switched caching full size on and off
// which would result in cache misses from source changing
var sourcePath string
if _, err = os.Stat(fullSizePath); os.IsNotExist(err) {
if _, err = os.Stat(largeSizePath); os.IsNotExist(err) {
l.Warn().Msgf("Could not find requested image %s. If this image is tied to an album or artist, it should be replaced", imgid.String())
serveDefaultImage(w, r, imageSize)
return
} else if err != nil {
// non-not found error for full file
l.Err(err).Msg("Failed to access source image file")
w.WriteHeader(http.StatusInternalServerError)
return
}
sourcePath = largeSizePath
} else if err != nil {
// non-not found error for full file
l.Err(err).Msg("Failed to access source image file")
w.WriteHeader(http.StatusInternalServerError)
return
} else {
sourcePath = fullSizePath
}
// source size file was found
// create and cache image at desired size
imageBuf, err := os.ReadFile(sourcePath)
if err != nil {
l.Err(err).Msg("Failed to read source image file")
w.WriteHeader(http.StatusInternalServerError)
return
}
err = catalog.CompressAndSaveImage(r.Context(), imgid.String(), imageSize, bytes.NewReader(imageBuf))
if err != nil {
l.Err(err).Msg("Failed to save compressed image to cache")
}
} else if err != nil {
// non-not found error for desired file
l.Err(err).Msg("Failed to access desired image file")
w.WriteHeader(http.StatusInternalServerError)
return
}
// Serve image
http.ServeFile(w, r, desiredImgPath)
}
}
func serveDefaultImage(w http.ResponseWriter, r *http.Request, size catalog.ImageSize) {
var lock sync.Mutex
l := logger.FromContext(r.Context())
defaultImagePath := filepath.Join(cfg.ConfigDir(), catalog.ImageCacheDir, string(size), "default_img")
if _, err := os.Stat(defaultImagePath); os.IsNotExist(err) {
l.Debug().Msg("Default image does not exist in cache at desired size")
defaultImagePath := filepath.Join(catalog.SourceImageDir(), "default_img")
if _, err = os.Stat(defaultImagePath); os.IsNotExist(err) {
l.Debug().Msg("Default image does not exist in cache, attempting to move...")
err = os.MkdirAll(filepath.Dir(defaultImagePath), 0755)
if err != nil {
l.Err(err).Msg("Error when attempting to create image_cache/full dir")
w.WriteHeader(http.StatusInternalServerError)
return
}
lock.Lock()
utils.CopyFile(path.Join("assets", "default_img"), defaultImagePath)
lock.Unlock()
} else if err != nil {
// non-not found error
l.Error().Err(err).Msg("Error when attempting to read default image in cache")
w.WriteHeader(http.StatusInternalServerError)
return
}
// default_img does (or now does) exist in cache at full size
file, err := os.Open(path.Join(catalog.SourceImageDir(), "default_img"))
if err != nil {
l.Err(err).Msg("Error when reading default image from source dir")
w.WriteHeader(http.StatusInternalServerError)
return
}
err = catalog.CompressAndSaveImage(r.Context(), "default_img", size, file)
if err != nil {
l.Err(err).Msg("Error when caching default img at desired size")
w.WriteHeader(http.StatusInternalServerError)
return
}
} else if err != nil {
// non-not found error
l.Error().Err(err).Msg("Error when attempting to read default image in cache")
w.WriteHeader(http.StatusInternalServerError)
return
}
// serve default_img at desired size
http.ServeFile(w, r, path.Join(cfg.ConfigDir(), catalog.ImageCacheDir, string(size), "default_img"))
}
// func SearchMissingAlbumImagesHandler(store db.DB) http.HandlerFunc {
// return func(w http.ResponseWriter, r *http.Request) {
// ctx := r.Context()
// l := logger.FromContext(ctx)
// l.Info().Msg("Beginning search for albums with missing images")
// go func() {
// defer func() {
// if r := recover(); r != nil {
// l.Error().Interface("recover", r).Msg("Panic when searching for missing album images")
// }
// }()
// ctx := logger.NewContext(l)
// from := int32(0)
// count := 0
// for {
// albums, err := store.AlbumsWithoutImages(ctx, from)
// if errors.Is(err, pgx.ErrNoRows) {
// break
// } else if err != nil {
// l.Err(err).Msg("Failed to search for missing images")
// return
// }
// l.Debug().Msgf("Queried %d albums on page %d", len(albums), from)
// if len(albums) < 1 {
// break
// }
// for _, a := range albums {
// l.Debug().Msgf("Searching images for album %s", a.Title)
// img, err := imagesrc.GetAlbumImages(ctx, imagesrc.AlbumImageOpts{
// Artists: utils.FlattenSimpleArtistNames(a.Artists),
// Album: a.Title,
// ReleaseMbzID: a.MbzID,
// })
// if err == nil && img != "" {
// l.Debug().Msg("Image found! Downloading...")
// imgid, err := catalog.DownloadAndCacheImage(ctx, img)
// if err != nil {
// l.Err(err).Msgf("Failed to download image for %s", a.Title)
// continue
// }
// err = store.UpdateAlbum(ctx, db.UpdateAlbumOpts{
// ID: a.ID,
// Image: imgid,
// })
// if err != nil {
// l.Err(err).Msgf("Failed to update image for %s", a.Title)
// continue
// }
// l.Info().Msgf("Found new album image for %s", a.Title)
// count++
// }
// if err != nil {
// l.Err(err).Msgf("Failed to get album images for %s", a.Title)
// }
// }
// from = albums[len(albums)-1].ID
// }
// l.Info().Msgf("Completed search, finding %d new images", count)
// }()
// w.WriteHeader(http.StatusOK)
// }
// }

View file

@ -0,0 +1,278 @@
package handlers
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/gabehf/koito/engine/middleware"
"github.com/gabehf/koito/internal/catalog"
"github.com/gabehf/koito/internal/cfg"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
mbz "github.com/gabehf/koito/internal/mbz"
"github.com/gabehf/koito/internal/utils"
"github.com/google/uuid"
"github.com/rs/zerolog"
"golang.org/x/sync/singleflight"
)
type LbzListenType string
const (
ListenTypeSingle LbzListenType = "single"
ListenTypePlayingNow LbzListenType = "playing_now"
ListenTypeImport LbzListenType = "import"
)
type LbzSubmitListenRequest struct {
ListenType LbzListenType `json:"listen_type,omitempty"`
Payload []LbzSubmitListenPayload `json:"payload,omitempty"`
}
type LbzSubmitListenPayload struct {
ListenedAt int64 `json:"listened_at,omitempty"`
TrackMeta LbzTrackMeta `json:"track_metadata"`
}
type LbzTrackMeta struct {
ArtistName string `json:"artist_name"` // required
TrackName string `json:"track_name"` // required
ReleaseName string `json:"release_name,omitempty"`
AdditionalInfo LbzAdditionalInfo `json:"additional_info,omitempty"`
}
type LbzAdditionalInfo struct {
MediaPlayer string `json:"media_player,omitempty"`
SubmissionClient string `json:"submission_client,omitempty"`
SubmissionClientVersion string `json:"submission_client_version,omitempty"`
ReleaseMBID string `json:"release_mbid,omitempty"`
ReleaseGroupMBID string `json:"release_group_mbid,omitempty"`
ArtistMBIDs []string `json:"artist_mbids,omitempty"`
ArtistNames []string `json:"artist_names,omitempty"`
RecordingMBID string `json:"recording_mbid,omitempty"`
DurationMs int32 `json:"duration_ms,omitempty"`
Duration int32 `json:"duration,omitempty"`
Tags []string `json:"tags,omitempty"`
AlbumArtist string `json:"albumartist,omitempty"`
}
const (
maxListensPerRequest = 1000
)
var sfGroup singleflight.Group
func LbzSubmitListenHandler(store db.DB, mbzc mbz.MusicBrainzCaller) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
l := logger.FromContext(r.Context())
var req LbzSubmitListenRequest
requestBytes, err := io.ReadAll(r.Body)
if err != nil {
utils.WriteError(w, "failed to read request body", http.StatusBadRequest)
return
}
if err := json.NewDecoder(bytes.NewBuffer(requestBytes)).Decode(&req); err != nil {
l.Debug().Err(err).Msg("Failed to decode request")
utils.WriteError(w, "failed to decode request", http.StatusBadRequest)
return
}
u := middleware.GetUserFromContext(r.Context())
if u == nil {
utils.WriteError(w, "unauthorized", http.StatusUnauthorized)
return
}
l.Debug().Any("request_body", req).Msg("Recieved request")
if len(req.Payload) < 1 {
l.Error().Msg("Payload is nil")
utils.WriteError(w, "payload is nil", http.StatusBadRequest)
return
}
if len(req.Payload) > maxListensPerRequest {
l.Error().Msg("Payload exceeds max listens per request")
utils.WriteError(w, "payload exceeds max listens per request", http.StatusBadRequest)
return
}
if len(req.Payload) != 1 && req.ListenType != "import" {
l.Error().Msg("Payload must only contain one listen for non-import requests")
utils.WriteError(w, "payload must only contain one listen for non-import requests", http.StatusBadRequest)
return
}
for _, payload := range req.Payload {
if payload.TrackMeta.ArtistName == "" || payload.TrackMeta.TrackName == "" {
l.Error().Msg("Artist name or track name are missing, unable to process listen")
utils.WriteError(w, "Artist name or track name are missing", http.StatusBadRequest)
return
}
if req.ListenType != ListenTypePlayingNow && req.ListenType != ListenTypeSingle && req.ListenType != ListenTypeImport {
l.Debug().Msg("No listen type provided, assuming 'single'")
req.ListenType = "single"
}
artistMbzIDs, err := utils.ParseUUIDSlice(payload.TrackMeta.AdditionalInfo.ArtistMBIDs)
if err != nil {
l.Debug().Err(err).Msg("Failed to parse one or more uuids")
}
rgMbzID, err := uuid.Parse(payload.TrackMeta.AdditionalInfo.ReleaseGroupMBID)
if err != nil {
rgMbzID = uuid.Nil
}
releaseMbzID, err := uuid.Parse(payload.TrackMeta.AdditionalInfo.ReleaseMBID)
if err != nil {
releaseMbzID = uuid.Nil
}
recordingMbzID, err := uuid.Parse(payload.TrackMeta.AdditionalInfo.RecordingMBID)
if err != nil {
recordingMbzID = uuid.Nil
}
var client string
if payload.TrackMeta.AdditionalInfo.MediaPlayer != "" {
client = payload.TrackMeta.AdditionalInfo.MediaPlayer
} else if payload.TrackMeta.AdditionalInfo.SubmissionClient != "" {
client = payload.TrackMeta.AdditionalInfo.SubmissionClient
}
var duration int32
if payload.TrackMeta.AdditionalInfo.Duration != 0 {
duration = payload.TrackMeta.AdditionalInfo.Duration
} else if payload.TrackMeta.AdditionalInfo.DurationMs != 0 {
duration = payload.TrackMeta.AdditionalInfo.DurationMs / 1000
}
var listenedAt = time.Now()
if payload.ListenedAt != 0 {
listenedAt = time.Unix(payload.ListenedAt, 0)
}
opts := catalog.SubmitListenOpts{
MbzCaller: mbzc,
ArtistNames: payload.TrackMeta.AdditionalInfo.ArtistNames,
Artist: payload.TrackMeta.ArtistName,
ArtistMbzIDs: artistMbzIDs,
TrackTitle: payload.TrackMeta.TrackName,
RecordingMbzID: recordingMbzID,
ReleaseTitle: payload.TrackMeta.ReleaseName,
ReleaseMbzID: releaseMbzID,
ReleaseGroupMbzID: rgMbzID,
Duration: duration,
Time: listenedAt,
UserID: u.ID,
Client: client,
}
if req.ListenType == ListenTypePlayingNow {
opts.SkipSaveListen = true
}
_, err, shared := sfGroup.Do(buildCaolescingKey(payload), func() (interface{}, error) {
return 0, catalog.SubmitListen(r.Context(), store, opts)
})
if shared {
l.Info().Msg("Duplicate requests detected; results were coalesced")
}
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("{\"status\": \"internal server error\"}"))
}
}
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("{\"status\": \"ok\"}"))
if cfg.LbzRelayEnabled() {
go doLbzRelay(requestBytes, l)
}
}
}
func doLbzRelay(requestBytes []byte, l *zerolog.Logger) {
defer func() {
if r := recover(); r != nil {
l.Error().Interface("recover", r).Msg("Panic in doLbzRelay")
}
}()
const (
maxRetryDuration = 10 * time.Second
initialBackoff = 1 * time.Second
maxBackoff = 4 * time.Second
)
req, err := http.NewRequest("POST", cfg.LbzRelayUrl()+"/submit-listens", bytes.NewBuffer(requestBytes))
if err != nil {
l.Error().Msg("Failed to build ListenBrainz relay request")
l.Error().Err(err).Send()
return
}
req.Header.Add("Authorization", "Token "+cfg.LbzRelayToken())
req.Header.Add("Content-Type", "application/json")
client := &http.Client{
Timeout: 5 * time.Second,
}
var resp *http.Response
var body []byte
start := time.Now()
backoff := initialBackoff
for {
resp, err = client.Do(req)
if err != nil {
l.Err(err).Msg("Failed to send ListenBrainz relay request")
return
}
defer resp.Body.Close()
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
l.Info().Msg("Successfully relayed ListenBrainz submission")
return
}
body, _ = io.ReadAll(resp.Body)
if resp.StatusCode >= 500 && time.Since(start)+backoff <= maxRetryDuration {
l.Warn().
Int("status", resp.StatusCode).
Str("response", string(body)).
Msg("Retryable server error from ListenBrainz relay, retrying...")
time.Sleep(backoff)
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
continue
}
// 4xx status or timeout exceeded
l.Warn().
Int("status", resp.StatusCode).
Str("response", string(body)).
Msg("Non-2XX response from ListenBrainz relay")
return
}
}
func buildCaolescingKey(p LbzSubmitListenPayload) string {
// the key not including the listen_type introduces the very rare possibility of a playing_now
// request taking precedence over a single, meaning that a listen will not be logged when it
// should, however that would require a playing_now request to fire a few seconds before a 'single'
// of the same track, which should never happen outside of misbehaving clients
//
// this could be fixed by restructuring the database inserts for idempotency, which would
// eliminate the need to coalesce responses, however i'm not gonna do that right now
return fmt.Sprintf("%s:%s:%s", p.TrackMeta.ArtistName, p.TrackMeta.TrackName, p.TrackMeta.ReleaseName)
}

View file

@ -0,0 +1,41 @@
package handlers
import (
"net/http"
"github.com/gabehf/koito/engine/middleware"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/utils"
)
type LbzValidateResponse struct {
Code int `json:"code"`
Error string `json:"error,omitempty"`
Message string `json:"message,omitempty"`
Valid bool `json:"valid,omitempty"`
UserName string `json:"user_name,omitempty"`
}
func LbzValidateTokenHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := logger.FromContext(ctx)
l.Debug().Msg("Validating user token...")
u := middleware.GetUserFromContext(ctx)
var response LbzValidateResponse
if u == nil {
response.Code = http.StatusUnauthorized
response.Error = "Incorrect Authorization"
w.WriteHeader(http.StatusUnauthorized)
utils.WriteJSON(w, http.StatusOK, response)
} else {
response.Code = 200
response.Message = "Token valid."
response.Valid = true
response.UserName = u.Username
utils.WriteJSON(w, http.StatusOK, response)
}
}
}

97
engine/handlers/merge.go Normal file
View file

@ -0,0 +1,97 @@
package handlers
import (
"net/http"
"strconv"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/utils"
)
func MergeTracksHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
l := logger.FromContext(r.Context())
fromidStr := r.URL.Query().Get("from_id")
fromId, err := strconv.Atoi(fromidStr)
if err != nil {
l.Err(err).Send()
utils.WriteError(w, "from_id is invalid", 400)
return
}
toidStr := r.URL.Query().Get("to_id")
toId, err := strconv.Atoi(toidStr)
if err != nil {
l.Err(err).Send()
utils.WriteError(w, "to_id is invalid", 400)
return
}
err = store.MergeTracks(r.Context(), int32(fromId), int32(toId))
if err != nil {
l.Err(err).Send()
utils.WriteError(w, "Failed to merge tracks: "+err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusNoContent)
}
}
func MergeReleaseGroupsHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
l := logger.FromContext(r.Context())
fromidStr := r.URL.Query().Get("from_id")
fromId, err := strconv.Atoi(fromidStr)
if err != nil {
l.Err(err).Send()
utils.WriteError(w, "from_id is invalid", 400)
return
}
toidStr := r.URL.Query().Get("to_id")
toId, err := strconv.Atoi(toidStr)
if err != nil {
l.Err(err).Send()
utils.WriteError(w, "to_id is invalid", 400)
return
}
err = store.MergeAlbums(r.Context(), int32(fromId), int32(toId))
if err != nil {
l.Err(err).Send()
utils.WriteError(w, "Failed to merge albums: "+err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusNoContent)
}
}
func MergeArtistsHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
l := logger.FromContext(r.Context())
fromidStr := r.URL.Query().Get("from_id")
fromId, err := strconv.Atoi(fromidStr)
if err != nil {
l.Err(err).Send()
utils.WriteError(w, "from_id is invalid", 400)
return
}
toidStr := r.URL.Query().Get("to_id")
toId, err := strconv.Atoi(toidStr)
if err != nil {
l.Err(err).Send()
utils.WriteError(w, "to_id is invalid", 400)
return
}
err = store.MergeArtists(r.Context(), int32(fromId), int32(toId))
if err != nil {
l.Err(err).Send()
utils.WriteError(w, "Failed to merge artists: "+err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusNoContent)
}
}

View file

@ -0,0 +1,178 @@
package handlers
import (
"io"
"net/http"
"strconv"
"strings"
"github.com/gabehf/koito/internal/catalog"
"github.com/gabehf/koito/internal/cfg"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/utils"
"github.com/google/uuid"
)
type ReplaceImageResponse struct {
Success bool `json:"success"`
Image string `json:"image"`
Message string `json:"message,omitempty"`
}
func ReplaceImageHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := logger.FromContext(ctx)
artistIdStr := r.FormValue("artist_id")
artistId, _ := strconv.Atoi(artistIdStr)
albumIdStr := r.FormValue("album_id")
albumId, _ := strconv.Atoi(albumIdStr)
if artistId != 0 && albumId != 0 {
utils.WriteError(w, "Only one of artist_id and album_id can be set", http.StatusBadRequest)
return
} else if artistId == 0 && albumId == 0 {
utils.WriteError(w, "One of artist_id and album_id must be set", http.StatusBadRequest)
return
}
var oldImage *uuid.UUID
if artistId != 0 {
a, err := store.GetArtist(ctx, db.GetArtistOpts{
ID: int32(artistId),
})
if err != nil {
utils.WriteError(w, "Artist with specified id could not be found", http.StatusBadRequest)
return
}
oldImage = a.Image
} else if albumId != 0 {
a, err := store.GetAlbum(ctx, db.GetAlbumOpts{
ID: int32(albumId),
})
if err != nil {
utils.WriteError(w, "Album with specified id could not be found", http.StatusBadRequest)
return
}
oldImage = a.Image
}
l.Debug().Msgf("Getting image from request...")
var id uuid.UUID
var err error
fileUrl := r.FormValue("image_url")
if fileUrl != "" {
l.Debug().Msg("Image identified as remote file")
err = catalog.ValidateImageURL(fileUrl)
if err != nil {
utils.WriteError(w, "url is invalid or not an image file", http.StatusBadRequest)
return
}
id = uuid.New()
var dlSize catalog.ImageSize
if cfg.FullImageCacheEnabled() {
dlSize = catalog.ImageSizeFull
} else {
dlSize = catalog.ImageSizeLarge
}
l.Debug().Msg("Downloading album image from source...")
err = catalog.DownloadAndCacheImage(ctx, id, fileUrl, dlSize)
if err != nil {
l.Err(err).Msg("Failed to cache image")
}
} else {
file, _, err := r.FormFile("image")
if err != nil {
utils.WriteError(w, "Invalid file", http.StatusBadRequest)
return
}
defer file.Close()
buf := make([]byte, 512)
if _, err := file.Read(buf); err != nil {
utils.WriteError(w, "Could not read file", http.StatusInternalServerError)
return
}
contentType := http.DetectContentType(buf)
if !strings.HasPrefix(contentType, "image/") {
utils.WriteError(w, "Only image uploads are allowed", http.StatusBadRequest)
return
}
if _, err := file.Seek(0, io.SeekStart); err != nil {
utils.WriteError(w, "Could not seek file", http.StatusInternalServerError)
return
}
l.Debug().Msgf("Saving image to cache...")
id = uuid.New()
var dlSize catalog.ImageSize
if cfg.FullImageCacheEnabled() {
dlSize = catalog.ImageSizeFull
} else {
dlSize = catalog.ImageSizeLarge
}
err = catalog.CompressAndSaveImage(ctx, id.String(), dlSize, file)
if err != nil {
utils.WriteError(w, "Could not save file", http.StatusInternalServerError)
return
}
}
l.Debug().Msgf("Updating database...")
var imgsrc string
if fileUrl != "" {
imgsrc = fileUrl
} else {
imgsrc = catalog.ImageSourceUserUpload
}
if artistId != 0 {
err := store.UpdateArtist(ctx, db.UpdateArtistOpts{
ID: int32(artistId),
Image: id,
ImageSrc: imgsrc,
})
if err != nil {
l.Err(err).Msg("Artist image could not be updated")
utils.WriteError(w, "Artist image could not be updated", http.StatusInternalServerError)
return
}
} else if albumId != 0 {
err := store.UpdateAlbum(ctx, db.UpdateAlbumOpts{
ID: int32(albumId),
Image: id,
ImageSrc: imgsrc,
})
if err != nil {
l.Err(err).Msg("Album image could not be updated")
utils.WriteError(w, "Album image could not be updated", http.StatusInternalServerError)
return
}
}
if oldImage != nil {
l.Debug().Msg("Cleaning up old image file...")
err = catalog.DeleteImage(*oldImage)
if err != nil {
l.Err(err).Msg("Failed to delete old image file")
utils.WriteError(w, "Could not delete old image file", http.StatusInternalServerError)
return
}
}
utils.WriteJSON(w, http.StatusOK, ReplaceImageResponse{
Success: true,
Image: id.String(),
})
}
}

View file

@ -0,0 +1 @@
package handlers_test

47
engine/handlers/search.go Normal file
View file

@ -0,0 +1,47 @@
package handlers
import (
"net/http"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/utils"
)
type SearchResults struct {
Artists []*models.Artist `json:"artists"`
Albums []*models.Album `json:"albums"`
Tracks []*models.Track `json:"tracks"`
}
func SearchHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := logger.FromContext(ctx)
q := r.URL.Query().Get("q")
artists, err := store.SearchArtists(ctx, q)
if err != nil {
l.Err(err).Msg("Failed to search for artists")
utils.WriteError(w, "failed to search in database", http.StatusInternalServerError)
return
}
albums, err := store.SearchAlbums(ctx, q)
if err != nil {
l.Err(err).Msg("Failed to search for albums")
utils.WriteError(w, "failed to search in database", http.StatusInternalServerError)
return
}
tracks, err := store.SearchTracks(ctx, q)
if err != nil {
l.Err(err).Msg("Failed to search for tracks")
utils.WriteError(w, "failed to search in database", http.StatusInternalServerError)
return
}
utils.WriteJSON(w, http.StatusOK, SearchResults{
Artists: artists,
Albums: albums,
Tracks: tracks,
})
}
}

77
engine/handlers/stats.go Normal file
View file

@ -0,0 +1,77 @@
package handlers
import (
"net/http"
"strings"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/utils"
)
type StatsResponse struct {
ListenCount int64 `json:"listen_count"`
TrackCount int64 `json:"track_count"`
AlbumCount int64 `json:"album_count"`
ArtistCount int64 `json:"artist_count"`
HoursListened int64 `json:"hours_listened"`
}
func StatsHandler(store db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
l := logger.FromContext(r.Context())
var period db.Period
switch strings.ToLower(r.URL.Query().Get("period")) {
case "day":
period = db.PeriodDay
case "week":
period = db.PeriodWeek
case "month":
period = db.PeriodMonth
case "year":
period = db.PeriodYear
case "all_time":
period = db.PeriodAllTime
default:
l.Debug().Msgf("Using default value '%s' for period", db.PeriodDay)
period = db.PeriodDay
}
listens, err := store.CountListens(r.Context(), period)
if err != nil {
l.Err(err).Send()
utils.WriteError(w, "failed to get listens: "+err.Error(), http.StatusInternalServerError)
return
}
tracks, err := store.CountTracks(r.Context(), period)
if err != nil {
l.Err(err).Send()
utils.WriteError(w, "failed to get listens: "+err.Error(), http.StatusInternalServerError)
return
}
albums, err := store.CountAlbums(r.Context(), period)
if err != nil {
l.Err(err).Send()
utils.WriteError(w, "failed to get listens: "+err.Error(), http.StatusInternalServerError)
return
}
artists, err := store.CountArtists(r.Context(), period)
if err != nil {
l.Err(err).Send()
utils.WriteError(w, "failed to get listens: "+err.Error(), http.StatusInternalServerError)
return
}
timeListenedS, err := store.CountTimeListened(r.Context(), period)
if err != nil {
l.Err(err).Send()
utils.WriteError(w, "failed to get listens: "+err.Error(), http.StatusInternalServerError)
return
}
utils.WriteJSON(w, http.StatusOK, StatsResponse{
ListenCount: listens,
TrackCount: tracks,
AlbumCount: albums,
ArtistCount: artists,
HoursListened: timeListenedS / 60 / 60,
})
}
}

64
engine/import_test.go Normal file
View file

@ -0,0 +1,64 @@
package engine_test
import (
"context"
"os"
"path/filepath"
"testing"
"github.com/gabehf/koito/engine"
"github.com/gabehf/koito/internal/cfg"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestImportMaloja(t *testing.T) {
src := "../static/maloja_import_test.json"
destDir := filepath.Join(cfg.ConfigDir(), "import")
dest := filepath.Join(destDir, "maloja_import_test.json")
// not going to make the dest dir because engine should make it already
input, err := os.ReadFile(src)
require.NoError(t, err)
require.NoError(t, os.WriteFile(dest, input, os.ModePerm))
engine.RunImporter(logger.Get(), store)
// maloja test import is 38 Magnify Tokyo streams
a, err := store.GetArtist(context.Background(), db.GetArtistOpts{Name: "Magnify Tokyo"})
require.NoError(t, err)
t.Log(a)
assert.Equal(t, "Magnify Tokyo", a.Name)
assert.EqualValues(t, 38, a.ListenCount)
}
func TestImportSpotify(t *testing.T) {
src := "../static/Streaming_History_Audio_spotify_import_test.json"
destDir := filepath.Join(cfg.ConfigDir(), "import")
dest := filepath.Join(destDir, "Streaming_History_Audio_spotify_import_test.json")
// not going to make the dest dir because engine should make it already
input, err := os.ReadFile(src)
require.NoError(t, err)
require.NoError(t, os.WriteFile(dest, input, os.ModePerm))
engine.RunImporter(logger.Get(), store)
a, err := store.GetArtist(context.Background(), db.GetArtistOpts{Name: "The Story So Far"})
require.NoError(t, err)
track, err := store.GetTrack(context.Background(), db.GetTrackOpts{Title: "Clairvoyant", ArtistIDs: []int32{a.ID}})
require.NoError(t, err)
t.Log(track)
assert.Equal(t, "Clairvoyant", track.Title)
// spotify includes duration data, but we only import when reason_end = trackdone
// this is the only track with valid duration data
assert.EqualValues(t, 181, track.Duration)
}

734
engine/long_test.go Normal file
View file

@ -0,0 +1,734 @@
package engine_test
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/url"
"os"
"path"
"strings"
"sync"
"testing"
"time"
"github.com/gabehf/koito/engine/handlers"
"github.com/gabehf/koito/internal/cfg"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/models"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var session string
var apikey string
var loginOnce sync.Once
var apikeyOnce sync.Once
func login(t *testing.T) {
loginOnce.Do(func() {
formdata := url.Values{}
formdata.Set("username", cfg.DefaultUsername())
formdata.Set("password", cfg.DefaultPassword())
encoded := formdata.Encode()
resp, err := http.DefaultClient.Post(host()+"/apis/web/v1/login", "application/x-www-form-urlencoded", strings.NewReader(encoded))
respBytes, _ := io.ReadAll(resp.Body)
t.Logf("Login request response: %s - %s", resp.Status, respBytes)
require.NoError(t, err)
require.Len(t, resp.Cookies(), 1)
session = resp.Cookies()[0].Value
require.NotEmpty(t, session)
})
}
func makeAuthRequest(t *testing.T, session, method, endpoint string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequest(method, host()+endpoint, body)
require.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: "koito_session",
Value: session,
})
t.Logf("Making request to %s with session: %s", endpoint, session)
return http.DefaultClient.Do(req)
}
// Expects a valid session
func getApiKey(t *testing.T, session string) {
apikeyOnce.Do(func() {
resp, err := makeAuthRequest(t, session, "GET", "/apis/web/v1/user/apikeys", nil)
require.NoError(t, err)
var keys []models.ApiKey
err = json.NewDecoder(resp.Body).Decode(&keys)
require.NoError(t, err)
require.GreaterOrEqual(t, len(keys), 1)
apikey = keys[0].Key
})
}
func truncateTestData(t *testing.T) {
err := store.Exec(context.Background(),
`TRUNCATE
artists,
artist_aliases,
tracks,
artist_tracks,
releases,
artist_releases,
release_aliases,
listens
RESTART IDENTITY CASCADE`)
require.NoError(t, err)
}
func doSubmitListens(t *testing.T) {
login(t)
getApiKey(t, session)
truncateTestData(t)
bodies := []string{fmt.Sprintf(`{
"listen_type": "single",
"payload": [
{
"listened_at": %d,
"track_metadata": {
"additional_info": {
"artist_mbids": [
"efc787f0-046f-4a60-beff-77b398c8cdf4"
],
"artist_names": [
"さユり"
],
"duration_ms": 275960,
"recording_mbid": "21524d55-b1f8-45d1-b172-976cba447199",
"release_group_mbid": "3281e0d9-fa44-4337-a8ce-6f264beeae16",
"release_mbid": "eb790e90-0065-4852-b47d-bbeede4aa9fc",
"submission_client": "navidrome",
"submission_client_version": "0.56.1 (fa2cf362)"
},
"artist_name": "さユり",
"release_name": "酸欠少女",
"track_name": "花の塔"
}
}
]
}`, time.Now().Add(-2*time.Hour).Unix()), // yesterday
fmt.Sprintf(`{
"listen_type": "single",
"payload": [
{
"listened_at": %d,
"track_metadata": {
"additional_info": {
"artist_mbids": [
"80b3cb83-b7a3-4f79-ad42-8325cefb3626"
],
"artist_names": [
"キタニタツヤ"
],
"duration_ms": 197270,
"recording_mbid": "4e909c21-e7a8-404d-b75a-0c8c2926efb0",
"release_group_mbid": "89069d92-e495-462c-b189-3431551868ed",
"release_mbid": "e16a49d6-77f3-4d73-b93c-cac855ce6ad5",
"submission_client": "navidrome",
"submission_client_version": "0.56.1 (fa2cf362)"
},
"artist_name": "キタニタツヤ",
"release_name": "Where Our Blue Is",
"track_name": "Where Our Blue Is"
}
}
]
}`, time.Now().Unix()),
fmt.Sprintf(`{
"listen_type": "single",
"payload": [
{
"listened_at": %d,
"track_metadata": {
"additional_info": {
"artist_mbids": [
"1262ab85-308b-46e7-b0b5-91fef8e46b62"
],
"artist_names": [
"ネクライトーキー"
],
"duration_ms": 241560,
"recording_mbid": "8eec4f3f-a059-4217-aad1-fbf82e33e756",
"release_group_mbid": "14f1aff0-dd19-4b42-82dd-720386b6d4c1",
"release_mbid": "7762d7af-7b6c-454f-977e-1b261743e265",
"submission_client": "navidrome",
"submission_client_version": "0.56.1 (fa2cf362)"
},
"artist_name": "ネクライトーキー",
"release_name": "ONE!",
"track_name": "こんがらがった!"
}
}
]
}`, time.Now().Add(-1*time.Hour).Unix())}
for _, body := range bodies {
req, err := http.NewRequest("POST", host()+"/apis/listenbrainz/1/submit-listens", strings.NewReader(body))
require.NoError(t, err)
req.Header.Add("Authorization", fmt.Sprintf("Token %s", apikey))
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
respBytes, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, `{"status": "ok"}`, string(respBytes))
}
}
func TestGetters(t *testing.T) {
t.Run("Submit Listens", doSubmitListens)
// Artist was saved
resp, err := http.DefaultClient.Get(host() + "/apis/web/v1/artist?id=1")
assert.NoError(t, err)
var artist models.Artist
err = json.NewDecoder(resp.Body).Decode(&artist)
require.NoError(t, err)
assert.Equal(t, "さユり", artist.Name)
// Album was saved
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/album?id=1")
assert.NoError(t, err)
var album models.Album
err = json.NewDecoder(resp.Body).Decode(&album)
require.NoError(t, err)
assert.Equal(t, "酸欠少女", album.Title)
// Track was saved
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/track?id=1")
assert.NoError(t, err)
var track models.Track
err = json.NewDecoder(resp.Body).Decode(&track)
require.NoError(t, err)
assert.Equal(t, "花の塔", track.Title)
// Listen was saved
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/listens")
assert.NoError(t, err)
var listens db.PaginatedResponse[models.Listen]
err = json.NewDecoder(resp.Body).Decode(&listens)
require.NoError(t, err)
require.Len(t, listens.Items, 3)
assert.EqualValues(t, 2, listens.Items[0].Track.ID)
assert.Equal(t, "Where Our Blue Is", listens.Items[0].Track.Title)
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/top-artists")
assert.NoError(t, err)
var artists db.PaginatedResponse[models.Artist]
err = json.NewDecoder(resp.Body).Decode(&artists)
require.NoError(t, err)
require.Len(t, artists.Items, 3)
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/top-albums")
assert.NoError(t, err)
var albums db.PaginatedResponse[models.Album]
err = json.NewDecoder(resp.Body).Decode(&albums)
require.NoError(t, err)
require.Len(t, albums.Items, 3)
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/top-tracks")
assert.NoError(t, err)
var tracks db.PaginatedResponse[models.Track]
err = json.NewDecoder(resp.Body).Decode(&tracks)
require.NoError(t, err)
require.Len(t, tracks.Items, 3)
truncateTestData(t)
}
func TestMerge(t *testing.T) {
t.Run("Submit Listens", doSubmitListens)
resp, err := makeAuthRequest(t, session, "POST", "/apis/web/v1/merge/tracks?from_id=1&to_id=2", nil)
require.NoError(t, err)
require.Equal(t, 204, resp.StatusCode)
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/track?id=2")
require.NoError(t, err)
var track models.Track
err = json.NewDecoder(resp.Body).Decode(&track)
require.NoError(t, err)
assert.EqualValues(t, 2, track.ListenCount)
truncateTestData(t)
t.Run("Submit Listens", doSubmitListens)
resp, err = makeAuthRequest(t, session, "POST", "/apis/web/v1/merge/artists?from_id=1&to_id=2", nil)
require.NoError(t, err)
require.Equal(t, 204, resp.StatusCode)
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/artist?id=2")
require.NoError(t, err)
var artist models.Artist
err = json.NewDecoder(resp.Body).Decode(&artist)
require.NoError(t, err)
assert.EqualValues(t, 2, artist.ListenCount)
truncateTestData(t)
t.Run("Submit Listens", doSubmitListens)
resp, err = makeAuthRequest(t, session, "POST", "/apis/web/v1/merge/albums?from_id=1&to_id=2", nil)
require.NoError(t, err)
require.Equal(t, 204, resp.StatusCode)
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/album?id=2")
require.NoError(t, err)
var album models.Album
err = json.NewDecoder(resp.Body).Decode(&album)
require.NoError(t, err)
assert.EqualValues(t, 2, album.ListenCount)
truncateTestData(t)
}
func TestValidateToken(t *testing.T) {
login(t)
getApiKey(t, session)
req, err := http.NewRequest("GET", host()+"/apis/listenbrainz/1/validate-token", nil)
require.NoError(t, err)
req.Header.Add("Authorization", fmt.Sprintf("Token %s", apikey))
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
var actual handlers.LbzValidateResponse
require.NoError(t, json.NewDecoder(resp.Body).Decode(&actual))
t.Log(actual)
var expected handlers.LbzValidateResponse
expected.Code = 200
expected.Message = "Token valid."
expected.Valid = true
expected.UserName = "test"
assert.True(t, assert.ObjectsAreEqual(expected, actual))
req, err = http.NewRequest("GET", host()+"/apis/listenbrainz/1/validate-token", nil)
require.NoError(t, err)
req.Header.Add("Authorization", "Token thisisasuperinvalidtoken")
resp, err = http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, 401, resp.StatusCode)
req, err = http.NewRequest("GET", host()+"/apis/listenbrainz/1/validate-token", nil)
require.NoError(t, err)
resp, err = http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, 401, resp.StatusCode)
}
func TestDelete(t *testing.T) {
t.Run("Submit Listens", doSubmitListens)
resp, err := makeAuthRequest(t, session, "DELETE", "/apis/web/v1/artist?id=1", nil)
require.NoError(t, err)
require.Equal(t, 204, resp.StatusCode)
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/artist?id=1")
require.NoError(t, err)
require.Equal(t, 404, resp.StatusCode)
resp, err = makeAuthRequest(t, session, "DELETE", "/apis/web/v1/album?id=1", nil)
require.NoError(t, err)
require.Equal(t, 204, resp.StatusCode)
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/album?id=1")
require.NoError(t, err)
require.Equal(t, 404, resp.StatusCode)
resp, err = makeAuthRequest(t, session, "DELETE", "/apis/web/v1/track?id=1", nil)
require.NoError(t, err)
require.Equal(t, 204, resp.StatusCode)
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/track?id=1")
require.NoError(t, err)
require.Equal(t, 404, resp.StatusCode)
truncateTestData(t)
}
func TestAliasesAndSearch(t *testing.T) {
t.Run("Submit Listens", doSubmitListens)
resp, err := makeAuthRequest(t, session, "POST", "/apis/web/v1/aliases?artist_id=1&alias=Sayuri", nil)
require.NoError(t, err)
require.Equal(t, 201, resp.StatusCode)
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/aliases?artist_id=1")
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
var actual []models.Alias
require.NoError(t, json.NewDecoder(resp.Body).Decode(&actual))
require.Len(t, actual, 2)
assert.Equal(t, actual[0].Alias, "さユり")
assert.Equal(t, actual[0].Source, "Canonical")
assert.Equal(t, actual[1].Alias, "Sayuri")
assert.Equal(t, actual[1].Source, "Manual")
resp, err = makeAuthRequest(t, session, "POST", "/apis/web/v1/aliases?album_id=1&alias=Sanketsu+Girl", nil)
require.NoError(t, err)
require.Equal(t, 201, resp.StatusCode)
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/aliases?album_id=1")
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
actual = nil
require.NoError(t, json.NewDecoder(resp.Body).Decode(&actual))
require.Len(t, actual, 2)
assert.Equal(t, actual[0].Alias, "酸欠少女")
assert.Equal(t, actual[0].Source, "Canonical")
assert.Equal(t, actual[1].Alias, "Sanketsu Girl")
assert.Equal(t, actual[1].Source, "Manual")
resp, err = makeAuthRequest(t, session, "POST", "/apis/web/v1/aliases?track_id=1&alias=Tower+of+Flower", nil)
require.NoError(t, err)
require.Equal(t, 201, resp.StatusCode)
resp, err = makeAuthRequest(t, session, "POST", "/apis/web/v1/aliases/primary?track_id=1&alias=Tower+of+Flower", nil)
require.NoError(t, err)
require.Equal(t, 204, resp.StatusCode)
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/track?id=1")
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
var track models.Track
require.NoError(t, json.NewDecoder(resp.Body).Decode(&track))
require.Len(t, actual, 2)
assert.Equal(t, track.Title, "Tower of Flower")
resp, err = makeAuthRequest(t, session, "POST", "/apis/web/v1/aliases/primary?artist_id=1&alias=Sayuri", nil)
require.NoError(t, err)
require.Equal(t, 204, resp.StatusCode)
// make sure searching works with aliases
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/search?q=Sanketsu")
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
var results handlers.SearchResults
require.NoError(t, json.NewDecoder(resp.Body).Decode(&results))
require.Len(t, results.Albums, 1)
assert.Equal(t, results.Albums[0].Title, "酸欠少女")
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/search?q=Sayuri")
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
results = handlers.SearchResults{}
require.NoError(t, json.NewDecoder(resp.Body).Decode(&results))
require.Len(t, results.Artists, 1)
assert.Equal(t, results.Artists[0].Name, "Sayuri") // reflects the new primary alias
truncateTestData(t)
}
func TestStats(t *testing.T) {
// zeroes
resp, err := http.DefaultClient.Get(host() + "/apis/web/v1/stats")
t.Log(resp)
require.NoError(t, err)
t.Run("Submit Listens", doSubmitListens)
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/stats")
t.Log(resp)
require.NoError(t, err)
var actual handlers.StatsResponse
require.NoError(t, json.NewDecoder(resp.Body).Decode(&actual))
assert.EqualValues(t, 3, actual.ListenCount)
assert.EqualValues(t, 3, actual.TrackCount)
assert.EqualValues(t, 3, actual.AlbumCount)
assert.EqualValues(t, 3, actual.ArtistCount)
assert.EqualValues(t, 0, actual.HoursListened)
}
func TestListenActivity(t *testing.T) {
// this test fails when run a bit after midnight
// i'll figure out a better test later
// t.Run("Submit Listens", doSubmitListens)
// resp, err := http.DefaultClient.Get(host() + "/apis/web/v1/listen-activity?range=3")
// t.Log(resp)
// require.NoError(t, err)
// var actual []db.ListenActivityItem
// require.NoError(t, json.NewDecoder(resp.Body).Decode(&actual))
// t.Log(actual)
// require.Len(t, actual, 3)
// assert.EqualValues(t, 3, actual[2].Listens)
}
func TestAuth(t *testing.T) {
// logs in a new session
formdata := url.Values{}
formdata.Set("username", cfg.DefaultUsername())
formdata.Set("password", cfg.DefaultPassword())
encoded := formdata.Encode()
resp, err := http.DefaultClient.Post(host()+"/apis/web/v1/login", "application/x-www-form-urlencoded", strings.NewReader(encoded))
respBytes, _ := io.ReadAll(resp.Body)
t.Logf("Login request response: %s - %s", resp.Status, respBytes)
require.NoError(t, err)
require.Len(t, resp.Cookies(), 1)
s := resp.Cookies()[0].Value
require.NotEmpty(t, s)
// test update user
req, err := http.NewRequest("PATCH", host()+"/apis/web/v1/user?username=new&password=supersecret", nil)
require.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: "koito_session",
Value: s,
})
resp, err = http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, 204, resp.StatusCode)
// test /me with updated info
req, err = http.NewRequest("GET", host()+"/apis/web/v1/user/me", nil)
require.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: "koito_session",
Value: s,
})
resp, err = http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
var me models.User
require.NoError(t, json.NewDecoder(resp.Body).Decode(&me))
require.Equal(t, "new", me.Username)
// login with old password fails
formdata = url.Values{}
formdata.Set("username", cfg.DefaultUsername())
formdata.Set("password", cfg.DefaultPassword())
encoded = formdata.Encode()
resp, err = http.DefaultClient.Post(host()+"/apis/web/v1/login", "application/x-www-form-urlencoded", strings.NewReader(encoded))
require.NoError(t, err)
require.Equal(t, 400, resp.StatusCode)
// reset update so other tests dont fail
req, err = http.NewRequest("PATCH", host()+fmt.Sprintf("/apis/web/v1/user?username=%s&password=%s", cfg.DefaultUsername(), cfg.DefaultPassword()), nil)
require.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: "koito_session",
Value: s,
})
resp, err = http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, 204, resp.StatusCode)
// creates api key
req, err = http.NewRequest("POST", host()+"/apis/web/v1/user/apikeys?label=testing", nil)
require.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: "koito_session",
Value: s,
})
resp, err = http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, 201, resp.StatusCode)
var response struct {
Key string `json:"key"`
}
require.NoError(t, json.NewDecoder(resp.Body).Decode(&response))
require.NotEmpty(t, response.Key)
// validates api key
req, err = http.NewRequest("GET", host()+"/apis/listenbrainz/1/validate-token", nil)
require.NoError(t, err)
req.Header.Add("Authorization", fmt.Sprintf("Token %s", response.Key))
resp, err = http.DefaultClient.Do(req)
require.NoError(t, err)
var actual handlers.LbzValidateResponse
require.NoError(t, json.NewDecoder(resp.Body).Decode(&actual))
var expected handlers.LbzValidateResponse
expected.Code = 200
expected.Message = "Token valid."
expected.Valid = true
expected.UserName = "test"
assert.True(t, assert.ObjectsAreEqual(expected, actual))
// changes api key label
login(t) // i dont care about using the new session anymore
resp, err = makeAuthRequest(t, s, "PATCH", "/apis/web/v1/user/apikeys?id=2&label=well+tested", nil)
require.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
resp, err = makeAuthRequest(t, s, "GET", "/apis/web/v1/user/apikeys", nil)
require.NoError(t, err)
var keys []models.ApiKey
err = json.NewDecoder(resp.Body).Decode(&keys)
require.NoError(t, err)
require.GreaterOrEqual(t, len(keys), 2)
require.NotNil(t, keys[1].Label)
assert.Equal(t, "well tested", keys[1].Label)
// logs out
req, err = http.NewRequest("POST", host()+"/apis/web/v1/logout", nil)
require.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: "koito_session",
Value: s,
})
resp, err = http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, 204, resp.StatusCode)
// attempts to create an api key - unauthorized
formdata = url.Values{}
formdata.Set("label", "testing")
encoded = formdata.Encode()
req, err = http.NewRequest("POST", host()+"/apis/web/v1/user/apikeys", strings.NewReader(encoded))
require.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: "koito_session",
Value: s,
})
resp, err = http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, 401, resp.StatusCode)
}
func TestDeleteListen(t *testing.T) {
login(t)
getApiKey(t, session)
truncateTestData(t)
body := `{
"listen_type": "single",
"payload": [
{
"listened_at": 1749475719,
"track_metadata": {
"additional_info": {
"artist_mbids": [
"80b3cb83-b7a3-4f79-ad42-8325cefb3626"
],
"artist_names": [
"キタニタツヤ"
],
"duration_ms": 197270,
"recording_mbid": "4e909c21-e7a8-404d-b75a-0c8c2926efb0",
"release_group_mbid": "89069d92-e495-462c-b189-3431551868ed",
"release_mbid": "e16a49d6-77f3-4d73-b93c-cac855ce6ad5",
"submission_client": "navidrome",
"submission_client_version": "0.56.1 (fa2cf362)"
},
"artist_name": "キタニタツヤ",
"release_name": "Where Our Blue Is",
"track_name": "Where Our Blue Is"
}
}
]
}`
req, err := http.NewRequest("POST", host()+"/apis/listenbrainz/1/submit-listens", strings.NewReader(body))
require.NoError(t, err)
req.Header.Add("Authorization", fmt.Sprintf("Token %s", apikey))
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
respBytes, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, `{"status": "ok"}`, string(respBytes))
resp, err = makeAuthRequest(t, session, "DELETE", "/apis/web/v1/listen?track_id=1&unix=1749475719", nil)
require.NoError(t, err)
require.Equal(t, 204, resp.StatusCode)
// deletes are idempotent
resp, err = makeAuthRequest(t, session, "DELETE", "/apis/web/v1/listen?track_id=1&unix=1749475719", nil)
require.NoError(t, err)
require.Equal(t, 204, resp.StatusCode)
// listen is deleted
resp, err = http.DefaultClient.Get(host() + "/apis/web/v1/track?id=1")
require.NoError(t, err)
var track models.Track
err = json.NewDecoder(resp.Body).Decode(&track)
require.NoError(t, err)
assert.EqualValues(t, 0, track.ListenCount)
}
func TestArtistReplaceImage(t *testing.T) {
t.Run("Submit Listens", doSubmitListens)
buf := &bytes.Buffer{}
mpw := multipart.NewWriter(buf)
mpw.WriteField("artist_id", "1")
w, err := mpw.CreateFormFile("image", path.Join("..", "static", "yuu.jpg"))
require.NoError(t, err)
f, err := os.Open(path.Join("..", "static", "yuu.jpg"))
require.NoError(t, err)
defer f.Close()
_, err = io.Copy(w, f)
require.NoError(t, err)
require.NoError(t, mpw.Close())
req, err := http.NewRequest("POST", host()+"/apis/web/v1/replace-image", buf)
require.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: "koito_session",
Value: session,
})
req.Header.Add("Content-Type", mpw.FormDataContentType())
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
response := new(handlers.ReplaceImageResponse)
require.NoError(t, json.NewDecoder(resp.Body).Decode(response))
require.NotEmpty(t, response.Image)
newid, err := uuid.Parse(response.Image)
require.NoError(t, err)
a, err := store.GetArtist(context.Background(), db.GetArtistOpts{ID: 1})
require.NoError(t, err)
assert.NotNil(t, a.Image)
assert.Equal(t, newid, *a.Image)
}
func TestAlbumReplaceImage(t *testing.T) {
t.Run("Submit Listens", doSubmitListens)
buf := &bytes.Buffer{}
mpw := multipart.NewWriter(buf)
mpw.WriteField("album_id", "1")
w, err := mpw.CreateFormFile("image", path.Join("..", "static", "yuu.jpg"))
require.NoError(t, err)
f, err := os.Open(path.Join("..", "static", "yuu.jpg"))
require.NoError(t, err)
defer f.Close()
_, err = io.Copy(w, f)
require.NoError(t, err)
require.NoError(t, mpw.Close())
req, err := http.NewRequest("POST", host()+"/apis/web/v1/replace-image", buf)
require.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: "koito_session",
Value: session,
})
req.Header.Add("Content-Type", mpw.FormDataContentType())
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
response := new(handlers.ReplaceImageResponse)
require.NoError(t, json.NewDecoder(resp.Body).Decode(response))
require.NotEmpty(t, response.Image)
newid, err := uuid.Parse(response.Image)
require.NoError(t, err)
a, err := store.GetAlbum(context.Background(), db.GetAlbumOpts{ID: 1})
require.NoError(t, err)
assert.NotNil(t, a.Image)
assert.Equal(t, newid, *a.Image)
}

View file

@ -0,0 +1,24 @@
package middleware
import (
"net/http"
"slices"
"github.com/gabehf/koito/internal/cfg"
"github.com/gabehf/koito/internal/logger"
)
func AllowedHosts(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := logger.Get()
if cfg.AllowAllHosts() {
next.ServeHTTP(w, r)
return
} else if slices.Contains(cfg.AllowedHosts(), r.Host) {
next.ServeHTTP(w, r)
return
}
l.Warn().Msgf("Request denied from host %s. If you want to allow requests like this, add the host to your %s variable", r.Host, cfg.ALLOWED_HOSTS_ENV)
w.WriteHeader(http.StatusForbidden)
})
}

View file

@ -0,0 +1,103 @@
package middleware
import (
"context"
"crypto/rand"
"math/big"
"net/http"
"runtime/debug"
"strings"
"time"
"github.com/go-chi/chi/v5/middleware"
"github.com/rs/zerolog"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/utils"
)
type RequestIDHook struct{}
func (h RequestIDHook) Run(e *zerolog.Event, level zerolog.Level, msg string) {
if ctx := e.GetCtx(); ctx != nil {
if reqID, ok := ctx.Value("requestID").(string); ok {
e.Str("request_id", reqID)
}
}
}
const requestIDKey MiddlwareContextKey = "requestID"
const base62Chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
func GenerateRequestID() string {
const length = 8 // ~0.23% chance of collision in 1M requests
id := make([]byte, length)
for i := 0; i < length; i++ {
n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(base62Chars))))
id[i] = base62Chars[n.Int64()]
}
return string(id)
}
func WithRequestID(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqID := GenerateRequestID()
ctx := context.WithValue(r.Context(), requestIDKey, reqID)
w.Header().Set("X-Request-ID", reqID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// GetRequestID extracts the request ID from context
func GetRequestID(ctx context.Context) string {
if val, ok := ctx.Value(requestIDKey).(string); ok {
return val
}
return ""
}
// Logger logs requests and injects a request-scoped logger with a request ID into the context.
func Logger(baseLogger *zerolog.Logger) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
reqID := GetRequestID(r.Context())
l := baseLogger.With().Str("request_id", reqID).Logger()
// Inject logger with request_id into the context
r = logger.Inject(r, &l)
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
t1 := time.Now()
defer func() {
t2 := time.Now()
if rec := recover(); rec != nil {
l.Error().
Str("type", "error").
Timestamp().
Interface("recover_info", rec).
Bytes("debug_stack", debug.Stack()).
Msg("log system error")
utils.WriteError(ww, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
pathS := strings.Split(r.URL.Path, "/")
if len(pathS) > 1 && pathS[1] == "apis" {
l.Info().
Str("type", "access").
Timestamp().
Msgf("Received %s %s - Responded with %d in %.2fms", r.Method, r.URL.Path, ww.Status(), float64(t2.Sub(t1).Nanoseconds())/1_000_000.0)
} else {
l.Debug().
Str("type", "access").
Timestamp().
Msgf("Received %s %s - Responded with %d in %.2fms", r.Method, r.URL.Path, ww.Status(), float64(t2.Sub(t1).Nanoseconds())/1_000_000.0)
}
}()
next.ServeHTTP(ww, r)
}
return http.HandlerFunc(fn)
}
}

View file

@ -0,0 +1,106 @@
package middleware
import (
"context"
"net/http"
"strings"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/utils"
"github.com/google/uuid"
)
type MiddlwareContextKey string
const (
UserContextKey MiddlwareContextKey = "user"
apikeyContextKey MiddlwareContextKey = "apikeyID"
)
func ValidateSession(store db.DB) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := logger.FromContext(r.Context())
cookie, err := r.Cookie("koito_session")
var sid uuid.UUID
if err == nil {
sid, err = uuid.Parse(cookie.Value)
if err != nil {
utils.WriteError(w, "session cookie is invalid", http.StatusUnauthorized)
return
}
}
l.Debug().Msg("Retrieved login cookie from request")
u, err := store.GetUserBySession(r.Context(), sid)
if err != nil {
l.Err(err).Msg("Failed to get user from session")
utils.WriteError(w, "internal server error", http.StatusInternalServerError)
return
}
if u == nil {
utils.WriteError(w, "unauthorized", http.StatusUnauthorized)
return
}
ctx := context.WithValue(r.Context(), UserContextKey, u)
r = r.WithContext(ctx)
l.Debug().Msgf("Refreshing session for user '%s'", u.Username)
store.RefreshSession(r.Context(), sid, time.Now().Add(30*24*time.Hour))
l.Debug().Msgf("Refreshed session for user '%s'", u.Username)
next.ServeHTTP(w, r)
})
}
}
func ValidateApiKey(store db.DB) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := logger.FromContext(ctx)
authh := r.Header.Get("Authorization")
s := strings.Split(authh, "Token ")
if len(s) < 2 {
l.Debug().Msg("Authorization header must be formatted 'Token {token}'")
utils.WriteError(w, "unauthorized", http.StatusUnauthorized)
return
}
key := s[1]
u, err := store.GetUserByApiKey(ctx, key)
if err != nil {
l.Err(err).Msg("Failed to get user from database using api key")
utils.WriteError(w, "internal server error", http.StatusInternalServerError)
return
}
if u == nil {
l.Debug().Msg("Api key does not exist")
utils.WriteError(w, "unauthorized", http.StatusUnauthorized)
return
}
ctx = context.WithValue(r.Context(), UserContextKey, u)
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
})
}
}
func GetUserFromContext(ctx context.Context) *models.User {
user, ok := ctx.Value(UserContextKey).(*models.User)
if !ok {
return nil
}
return user
}

140
engine/routes.go Normal file
View file

@ -0,0 +1,140 @@
package engine
import (
"net/http"
"os"
"path/filepath"
"strings"
"sync/atomic"
"time"
"github.com/gabehf/koito/engine/handlers"
"github.com/gabehf/koito/engine/middleware"
"github.com/gabehf/koito/internal/cfg"
"github.com/gabehf/koito/internal/db"
mbz "github.com/gabehf/koito/internal/mbz"
"github.com/go-chi/chi/v5"
chimiddleware "github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
"github.com/go-chi/httprate"
)
func bindRoutes(
r *chi.Mux,
ready *atomic.Bool,
db db.DB,
mbz mbz.MusicBrainzCaller,
) {
r.With(chimiddleware.RequestSize(5<<20)).
With(middleware.AllowedHosts).
Get("/images/{size}/{filename}", handlers.ImageHandler(db))
r.Route("/apis/web/v1", func(r chi.Router) {
r.Use(middleware.AllowedHosts)
r.Get("/artist", handlers.GetArtistHandler(db))
r.Get("/album", handlers.GetAlbumHandler(db))
r.Get("/track", handlers.GetTrackHandler(db))
r.Get("/top-tracks", handlers.GetTopTracksHandler(db))
r.Get("/top-albums", handlers.GetTopAlbumsHandler(db))
r.Get("/top-artists", handlers.GetTopArtistsHandler(db))
r.Get("/listens", handlers.GetListensHandler(db))
r.Get("/listen-activity", handlers.GetListenActivityHandler(db))
r.Get("/stats", handlers.StatsHandler(db))
r.Get("/search", handlers.SearchHandler(db))
r.Get("/aliases", handlers.GetAliasesHandler(db))
r.Post("/logout", handlers.LogoutHandler(db))
if !cfg.RateLimitDisabled() {
r.With(httprate.Limit(
10,
time.Minute,
httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, `{"error":"too many requests"}`, http.StatusTooManyRequests)
}),
)).Post("/login", handlers.LoginHandler(db))
} else {
r.Post("/login", handlers.LoginHandler(db))
}
r.Get("/health", func(w http.ResponseWriter, r *http.Request) {
if !ready.Load() {
http.Error(w, "not ready", http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
})
r.Group(func(r chi.Router) {
r.Use(middleware.ValidateSession(db))
r.Post("/replace-image", handlers.ReplaceImageHandler(db))
r.Post("/merge/tracks", handlers.MergeTracksHandler(db))
r.Post("/merge/albums", handlers.MergeReleaseGroupsHandler(db))
r.Post("/merge/artists", handlers.MergeArtistsHandler(db))
r.Delete("/artist", handlers.DeleteArtistHandler(db))
r.Delete("/album", handlers.DeleteAlbumHandler(db))
r.Delete("/track", handlers.DeleteTrackHandler(db))
r.Delete("/listen", handlers.DeleteListenHandler(db))
r.Post("/aliases", handlers.CreateAliasHandler(db))
r.Delete("/aliases", handlers.DeleteAliasHandler(db))
r.Post("/aliases/primary", handlers.SetPrimaryAliasHandler(db))
r.Get("/user/apikeys", handlers.GetApiKeysHandler(db))
r.Post("/user/apikeys", handlers.GenerateApiKeyHandler(db))
r.Patch("/user/apikeys", handlers.UpdateApiKeyLabelHandler(db))
r.Delete("/user/apikeys", handlers.DeleteApiKeyHandler(db))
r.Get("/user/me", handlers.MeHandler(db))
r.Patch("/user", handlers.UpdateUserHandler(db))
})
})
r.Route("/apis/listenbrainz/1", func(r chi.Router) {
r.Use(cors.Handler(cors.Options{
AllowedOrigins: []string{"*"},
AllowedHeaders: []string{"Content-Type", "Authorization"},
}))
r.With(middleware.ValidateApiKey(db)).Post("/submit-listens", handlers.LbzSubmitListenHandler(db, mbz))
r.With(middleware.ValidateApiKey(db)).Get("/validate-token", handlers.LbzValidateTokenHandler(db))
})
// serve react client
workDir, _ := os.Getwd()
filesDir := http.Dir(filepath.Join(workDir, "client/build/client"))
fileServer(r, "/", filesDir)
// serve client public files
filesDir = http.Dir(filepath.Join(workDir, "client/public"))
publicServer(r, "/public", filesDir)
}
// FileServer conveniently sets up a http.FileServer handler to serve
// static files from a http.FileSystem.
func fileServer(r chi.Router, path string, root http.FileSystem) {
if strings.ContainsAny(path, "{}*") {
panic("FileServer does not permit any URL parameters.")
}
// Serve static files
fs := http.FileServer(root)
r.Get(path+"*", func(w http.ResponseWriter, r *http.Request) {
// Check if file exists
filePath := filepath.Join("client/build/client", strings.TrimPrefix(r.URL.Path, path))
if _, err := os.Stat(filePath); os.IsNotExist(err) {
// File doesn't exist, serve index.html
http.ServeFile(w, r, filepath.Join("client/build/client", "index.html"))
return
}
// Serve file normally
fs.ServeHTTP(w, r)
})
}
func publicServer(r chi.Router, path string, root http.FileSystem) {
if strings.ContainsAny(path, "{}*") {
panic("FileServer does not permit any URL parameters.")
}
fs := http.FileServer(root)
r.Get(path+"*", func(w http.ResponseWriter, r *http.Request) {
r.URL.Path = strings.TrimPrefix(r.URL.Path, path)
fs.ServeHTTP(w, r)
})
}