chore: initial public commit

This commit is contained in:
Gabe Farrell 2025-06-11 19:45:39 -04:00
commit fc9054b78c
250 changed files with 32809 additions and 0 deletions

82
internal/db/db.go Normal file
View file

@ -0,0 +1,82 @@
// package db defines the database interface
package db
import (
"context"
"time"
"github.com/gabehf/koito/internal/models"
"github.com/google/uuid"
)
type DB interface {
// Get
GetArtist(ctx context.Context, opts GetArtistOpts) (*models.Artist, error)
GetAlbum(ctx context.Context, opts GetAlbumOpts) (*models.Album, error)
GetTrack(ctx context.Context, opts GetTrackOpts) (*models.Track, error)
GetTopTracksPaginated(ctx context.Context, opts GetItemsOpts) (*PaginatedResponse[*models.Track], error)
GetTopArtistsPaginated(ctx context.Context, opts GetItemsOpts) (*PaginatedResponse[*models.Artist], error)
GetTopAlbumsPaginated(ctx context.Context, opts GetItemsOpts) (*PaginatedResponse[*models.Album], error)
GetListensPaginated(ctx context.Context, opts GetItemsOpts) (*PaginatedResponse[*models.Listen], error)
GetListenActivity(ctx context.Context, opts ListenActivityOpts) ([]ListenActivityItem, error)
GetAllArtistAliases(ctx context.Context, id int32) ([]models.Alias, error)
GetAllAlbumAliases(ctx context.Context, id int32) ([]models.Alias, error)
GetAllTrackAliases(ctx context.Context, id int32) ([]models.Alias, error)
GetApiKeysByUserID(ctx context.Context, id int32) ([]models.ApiKey, error)
GetUserBySession(ctx context.Context, sessionId uuid.UUID) (*models.User, error)
GetUserByUsername(ctx context.Context, username string) (*models.User, error)
GetUserByApiKey(ctx context.Context, key string) (*models.User, error)
// Save
SaveArtist(ctx context.Context, opts SaveArtistOpts) (*models.Artist, error)
SaveArtistAliases(ctx context.Context, id int32, aliases []string, source string) error
SaveAlbum(ctx context.Context, opts SaveAlbumOpts) (*models.Album, error)
SaveAlbumAliases(ctx context.Context, id int32, aliases []string, source string) error
SaveTrack(ctx context.Context, opts SaveTrackOpts) (*models.Track, error)
SaveTrackAliases(ctx context.Context, id int32, aliases []string, source string) error
SaveListen(ctx context.Context, opts SaveListenOpts) error
SaveUser(ctx context.Context, opts SaveUserOpts) (*models.User, error)
SaveApiKey(ctx context.Context, opts SaveApiKeyOpts) (*models.ApiKey, error)
SaveSession(ctx context.Context, userId int32, expiresAt time.Time, persistent bool) (*models.Session, error)
// Update
UpdateArtist(ctx context.Context, opts UpdateArtistOpts) error
UpdateTrack(ctx context.Context, opts UpdateTrackOpts) error
UpdateAlbum(ctx context.Context, opts UpdateAlbumOpts) error
AddArtistsToAlbum(ctx context.Context, opts AddArtistsToAlbumOpts) error
UpdateUser(ctx context.Context, opts UpdateUserOpts) error
UpdateApiKeyLabel(ctx context.Context, opts UpdateApiKeyLabelOpts) error
RefreshSession(ctx context.Context, sessionId uuid.UUID, expiresAt time.Time) error
SetPrimaryArtistAlias(ctx context.Context, id int32, alias string) error
SetPrimaryAlbumAlias(ctx context.Context, id int32, alias string) error
SetPrimaryTrackAlias(ctx context.Context, id int32, alias string) error
// Delete
DeleteArtist(ctx context.Context, id int32) error
DeleteAlbum(ctx context.Context, id int32) error
DeleteTrack(ctx context.Context, id int32) error
DeleteListen(ctx context.Context, trackId int32, listenedAt time.Time) error
DeleteArtistAlias(ctx context.Context, id int32, alias string) error
DeleteAlbumAlias(ctx context.Context, id int32, alias string) error
DeleteTrackAlias(ctx context.Context, id int32, alias string) error
DeleteSession(ctx context.Context, sessionId uuid.UUID) error
DeleteApiKey(ctx context.Context, id int32) error
// Count
CountListens(ctx context.Context, period Period) (int64, error)
CountTracks(ctx context.Context, period Period) (int64, error)
CountAlbums(ctx context.Context, period Period) (int64, error)
CountArtists(ctx context.Context, period Period) (int64, error)
CountTimeListened(ctx context.Context, period Period) (int64, error)
CountUsers(ctx context.Context) (int64, error)
// Search
SearchArtists(ctx context.Context, q string) ([]*models.Artist, error)
SearchAlbums(ctx context.Context, q string) ([]*models.Album, error)
SearchTracks(ctx context.Context, q string) ([]*models.Track, error)
// Merge
MergeTracks(ctx context.Context, fromId, toId int32) error
MergeAlbums(ctx context.Context, fromId, toId int32) error
MergeArtists(ctx context.Context, fromId, toId int32) error
// Etc
ImageHasAssociation(ctx context.Context, image uuid.UUID) (bool, error)
GetImageSource(ctx context.Context, image uuid.UUID) (string, error)
AlbumsWithoutImages(ctx context.Context, from int32) ([]*models.Album, error)
Ping(ctx context.Context) error
Close(ctx context.Context)
}

140
internal/db/opts.go Normal file
View file

@ -0,0 +1,140 @@
package db
import (
"time"
"github.com/gabehf/koito/internal/models"
"github.com/google/uuid"
)
type GetAlbumOpts struct {
ID int32
MusicBrainzID uuid.UUID
ArtistID int32
Title string
Titles []string
Image uuid.UUID
}
type GetArtistOpts struct {
ID int32
MusicBrainzID uuid.UUID
Name string
Image uuid.UUID
}
type GetTrackOpts struct {
ID int32
MusicBrainzID uuid.UUID
Title string
ArtistIDs []int32
}
type SaveTrackOpts struct {
Title string
AlbumID int32
ArtistIDs []int32
RecordingMbzID uuid.UUID
Duration int32
}
type SaveAlbumOpts struct {
Title string
MusicBrainzID uuid.UUID
Type string
ArtistIDs []int32
VariousArtists bool
Image uuid.UUID
ImageSrc string
Aliases []string
}
type SaveArtistOpts struct {
Name string
MusicBrainzID uuid.UUID
Aliases []string
Image uuid.UUID
ImageSrc string
}
type UpdateApiKeyLabelOpts struct {
UserID int32
ID int32
Label string
}
type SaveUserOpts struct {
Username string
Password string
Role models.UserRole
}
type SaveApiKeyOpts struct {
Key string
UserID int32
Label string
}
type SaveListenOpts struct {
TrackID int32
Time time.Time
UserID int32
Client string
}
type UpdateTrackOpts struct {
ID int32
MusicBrainzID uuid.UUID
Duration int32
}
type UpdateArtistOpts struct {
ID int32
MusicBrainzID uuid.UUID
Image uuid.UUID
ImageSrc string
}
type UpdateAlbumOpts struct {
ID int32
MusicBrainzID uuid.UUID
Image uuid.UUID
ImageSrc string
}
type UpdateUserOpts struct {
ID int32
Username string
Password string
}
type AddArtistsToAlbumOpts struct {
AlbumID int32
ArtistIDs []int32
}
type GetItemsOpts struct {
Limit int
Period Period
Page int
Week int // 1-52
Month int // 1-12
Year int
// Used only for getting top tracks
ArtistID int
AlbumID int
// Used for getting listens
TrackID int
}
type ListenActivityOpts struct {
Step StepInterval
Range int
Month int
Year int
AlbumID int32
ArtistID int32
TrackID int32
}

108
internal/db/period.go Normal file
View file

@ -0,0 +1,108 @@
package db
import (
"time"
)
// should this be in db package ???
type Period string
const (
PeriodDay Period = "day"
PeriodWeek Period = "week"
PeriodMonth Period = "month"
PeriodYear Period = "year"
PeriodAllTime Period = "all_time"
PeriodDefault Period = "day"
)
func StartTimeFromPeriod(p Period) time.Time {
now := time.Now()
switch p {
case "day":
return now.AddDate(0, 0, -1)
case "week":
return now.AddDate(0, 0, -7)
case "month":
return now.AddDate(0, -1, 0)
case "year":
return now.AddDate(-1, 0, 0)
case "all_time":
return time.Time{}
default:
// default 1 day
return now.AddDate(0, 0, -1)
}
}
type StepInterval string
const (
StepDay StepInterval = "day"
StepWeek StepInterval = "week"
StepMonth StepInterval = "month"
StepYear StepInterval = "year"
StepDefault StepInterval = "day"
DefaultRange int = 12
)
// start is the time of 00:00 at the beginning of opts.Range opts.Steps ago,
// end is the end time of the current opts.Step.
// E.g. if step is StepWeek and range is 4, start will be the time 00:00 on Sunday on the 4th week ago,
// and end will be 23:59:59 on Saturday at the end of the current week.
// If opts.Year (or opts.Year + opts.Month) is provided, start and end will simply by the start and end times of that year/month.
func ListenActivityOptsToTimes(opts ListenActivityOpts) (start, end time.Time) {
now := time.Now()
// If Year (and optionally Month) are specified, use calendar boundaries
if opts.Year != 0 {
if opts.Month != 0 {
// Specific month of a specific year
start = time.Date(opts.Year, time.Month(opts.Month), 1, 0, 0, 0, 0, now.Location())
end = start.AddDate(0, 1, 0).Add(-time.Nanosecond)
} else {
// Whole year
start = time.Date(opts.Year, 1, 1, 0, 0, 0, 0, now.Location())
end = start.AddDate(1, 0, 0).Add(-time.Nanosecond)
}
return start, end
}
// X days ago + today = range
opts.Range = opts.Range - 1
// Determine step and align accordingly
switch opts.Step {
case StepDay:
today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
start = today.AddDate(0, 0, -opts.Range)
end = today.AddDate(0, 0, 1).Add(-time.Nanosecond)
case StepWeek:
// Align to most recent Sunday
weekday := int(now.Weekday()) // Sunday = 0
startOfThisWeek := time.Date(now.Year(), now.Month(), now.Day()-weekday, 0, 0, 0, 0, now.Location())
start = startOfThisWeek.AddDate(0, 0, -7*opts.Range)
end = startOfThisWeek.AddDate(0, 0, 7).Add(-time.Nanosecond)
case StepMonth:
firstOfThisMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
start = firstOfThisMonth.AddDate(0, -opts.Range, 0)
end = firstOfThisMonth.AddDate(0, 1, 0).Add(-time.Nanosecond)
case StepYear:
firstOfThisYear := time.Date(now.Year(), 1, 1, 0, 0, 0, 0, now.Location())
start = firstOfThisYear.AddDate(-opts.Range, 0, 0)
end = firstOfThisYear.AddDate(1, 0, 0).Add(-time.Nanosecond)
default:
// Default to daily
today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
start = today.AddDate(0, 0, -opts.Range)
end = today.AddDate(0, 0, 1).Add(-time.Nanosecond)
}
return start, end
}

View file

@ -0,0 +1,28 @@
package db_test
import (
"testing"
"time"
)
func TestListenActivityOptsToTimes(t *testing.T) {
// default range
// opts := db.ListenActivityOpts{}
// t1, t2 := db.ListenActivityOptsToTimes(opts)
// t.Logf("%s to %s", t1, t2)
// assert.WithinDuration(t, bod(time.Now().Add(-11*24*time.Hour)), t1, 5*time.Second)
// assert.WithinDuration(t, eod(time.Now()), t2, 5*time.Second)
}
func eod(t time.Time) time.Time {
year, month, day := t.Date()
loc := t.Location()
return time.Date(year, month, day, 23, 59, 59, 0, loc)
}
func bod(t time.Time) time.Time {
year, month, day := t.Date()
loc := t.Location()
return time.Date(year, month, day, 0, 0, 0, 0, loc)
}

312
internal/db/psql/album.go Normal file
View file

@ -0,0 +1,312 @@
package psql
import (
"context"
"errors"
"strings"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/gabehf/koito/internal/utils"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
)
func (d *Psql) GetAlbum(ctx context.Context, opts db.GetAlbumOpts) (*models.Album, error) {
l := logger.FromContext(ctx)
var row repository.ReleasesWithTitle
var err error
if opts.ID != 0 {
l.Debug().Msgf("Fetching album from DB with id %d", opts.ID)
row, err = d.q.GetRelease(ctx, opts.ID)
} else if opts.MusicBrainzID != uuid.Nil {
l.Debug().Msgf("Fetching album from DB with MusicBrainz Release ID %s", opts.MusicBrainzID)
row, err = d.q.GetReleaseByMbzID(ctx, &opts.MusicBrainzID)
} else if opts.ArtistID != 0 && opts.Title != "" {
l.Debug().Msgf("Fetching album from DB with artist_id %d and title %s", opts.ArtistID, opts.Title)
row, err = d.q.GetReleaseByArtistAndTitle(ctx, repository.GetReleaseByArtistAndTitleParams{
ArtistID: opts.ArtistID,
Title: opts.Title,
})
} else if opts.ArtistID != 0 && len(opts.Titles) > 0 {
l.Debug().Msgf("Fetching release group from DB with artist_id %d and titles %v", opts.ArtistID, opts.Titles)
row, err = d.q.GetReleaseByArtistAndTitles(ctx, repository.GetReleaseByArtistAndTitlesParams{
ArtistID: opts.ArtistID,
Column1: opts.Titles,
})
} else {
return nil, errors.New("insufficient information to get album")
}
if err != nil {
return nil, err
}
count, err := d.q.CountListensFromRelease(ctx, repository.CountListensFromReleaseParams{
ListenedAt: time.Unix(0, 0),
ListenedAt_2: time.Now(),
ReleaseID: row.ID,
})
if err != nil {
return nil, err
}
return &models.Album{
ID: row.ID,
MbzID: row.MusicBrainzID,
Title: row.Title,
Image: row.Image,
VariousArtists: row.VariousArtists,
ListenCount: count,
}, nil
}
func (d *Psql) SaveAlbum(ctx context.Context, opts db.SaveAlbumOpts) (*models.Album, error) {
l := logger.FromContext(ctx)
var insertMbzID *uuid.UUID
var insertImage *uuid.UUID
if opts.MusicBrainzID != uuid.Nil {
insertMbzID = &opts.MusicBrainzID
}
if opts.Image != uuid.Nil {
insertImage = &opts.Image
}
if len(opts.ArtistIDs) < 1 {
return nil, errors.New("required parameter 'ArtistIDs' missing")
}
for _, aid := range opts.ArtistIDs {
if aid == 0 {
return nil, errors.New("none of 'ArtistIDs' may be 0")
}
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return nil, err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
l.Debug().Msgf("Inserting release '%s' into DB", opts.Title)
r, err := qtx.InsertRelease(ctx, repository.InsertReleaseParams{
MusicBrainzID: insertMbzID,
VariousArtists: opts.VariousArtists,
Image: insertImage,
ImageSource: pgtype.Text{String: opts.ImageSrc, Valid: opts.ImageSrc != ""},
})
if err != nil {
return nil, err
}
for _, artistId := range opts.ArtistIDs {
l.Debug().Msgf("Associating release '%s' to artist with ID %d", opts.Title, artistId)
err = qtx.AssociateArtistToRelease(ctx, repository.AssociateArtistToReleaseParams{
ArtistID: artistId,
ReleaseID: r.ID,
})
if err != nil {
return nil, err
}
}
l.Debug().Msgf("Saving canonical alias %s for release %d", opts.Title, r.ID)
err = qtx.InsertReleaseAlias(ctx, repository.InsertReleaseAliasParams{
ReleaseID: r.ID,
Alias: opts.Title,
Source: "Canonical",
IsPrimary: true,
})
if err != nil {
l.Err(err).Msgf("Failed to save canonical alias for album %d", r.ID)
}
err = tx.Commit(ctx)
if err != nil {
return nil, err
}
return &models.Album{
ID: r.ID,
MbzID: r.MusicBrainzID,
Title: opts.Title,
Image: r.Image,
VariousArtists: r.VariousArtists,
}, nil
}
func (d *Psql) AddArtistsToAlbum(ctx context.Context, opts db.AddArtistsToAlbumOpts) error {
l := logger.FromContext(ctx)
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
for _, id := range opts.ArtistIDs {
err := qtx.AssociateArtistToRelease(ctx, repository.AssociateArtistToReleaseParams{
ReleaseID: opts.AlbumID,
ArtistID: id,
})
if err != nil {
l.Error().Err(err).Msgf("Failed to associate release %d with artist %d", opts.AlbumID, id)
}
}
return tx.Commit(ctx)
}
func (d *Psql) UpdateAlbum(ctx context.Context, opts db.UpdateAlbumOpts) error {
l := logger.FromContext(ctx)
if opts.ID == 0 {
return errors.New("missing album id")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
if opts.MusicBrainzID != uuid.Nil {
l.Debug().Msgf("Updating release with ID %d with MusicBrainz ID %s", opts.ID, opts.MusicBrainzID)
err := qtx.UpdateReleaseMbzID(ctx, repository.UpdateReleaseMbzIDParams{
ID: opts.ID,
MusicBrainzID: &opts.MusicBrainzID,
})
if err != nil {
return err
}
}
if opts.Image != uuid.Nil {
l.Debug().Msgf("Updating release with ID %d with image %s", opts.ID, opts.Image)
err := qtx.UpdateReleaseImage(ctx, repository.UpdateReleaseImageParams{
ID: opts.ID,
Image: &opts.Image,
ImageSource: pgtype.Text{String: opts.ImageSrc, Valid: opts.ImageSrc != ""},
})
if err != nil {
return err
}
}
return tx.Commit(ctx)
}
func (d *Psql) SaveAlbumAliases(ctx context.Context, id int32, aliases []string, source string) error {
l := logger.FromContext(ctx)
if id == 0 {
return errors.New("album id not specified")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
existing, err := qtx.GetAllReleaseAliases(ctx, id)
if err != nil {
return err
}
for _, v := range existing {
aliases = append(aliases, v.Alias)
}
utils.Unique(&aliases)
for _, alias := range aliases {
if strings.TrimSpace(alias) == "" {
return errors.New("aliases cannot be blank")
}
err = qtx.InsertReleaseAlias(ctx, repository.InsertReleaseAliasParams{
Alias: strings.TrimSpace(alias),
ReleaseID: id,
Source: source,
IsPrimary: false,
})
if err != nil {
return err
}
}
return tx.Commit(ctx)
}
func (d *Psql) DeleteAlbum(ctx context.Context, id int32) error {
return d.q.DeleteRelease(ctx, id)
}
func (d *Psql) DeleteAlbumAlias(ctx context.Context, id int32, alias string) error {
return d.q.DeleteReleaseAlias(ctx, repository.DeleteReleaseAliasParams{
ReleaseID: id,
Alias: alias,
})
}
func (d *Psql) GetAllAlbumAliases(ctx context.Context, id int32) ([]models.Alias, error) {
rows, err := d.q.GetAllReleaseAliases(ctx, id)
if err != nil {
return nil, err
}
aliases := make([]models.Alias, len(rows))
for i, row := range rows {
aliases[i] = models.Alias{
ID: id,
Alias: row.Alias,
Source: row.Source,
Primary: row.IsPrimary,
}
}
return aliases, nil
}
func (d *Psql) SetPrimaryAlbumAlias(ctx context.Context, id int32, alias string) error {
l := logger.FromContext(ctx)
if id == 0 {
return errors.New("artist id not specified")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
// get all aliases
aliases, err := qtx.GetAllReleaseAliases(ctx, id)
if err != nil {
return err
}
primary := ""
exists := false
for _, v := range aliases {
if v.Alias == alias {
exists = true
}
if v.IsPrimary {
primary = v.Alias
}
}
if primary == alias {
// no-op rename
return nil
}
if !exists {
return errors.New("alias does not exist")
}
err = qtx.SetReleaseAliasPrimaryStatus(ctx, repository.SetReleaseAliasPrimaryStatusParams{
ReleaseID: id,
Alias: alias,
IsPrimary: true,
})
if err != nil {
return err
}
err = qtx.SetReleaseAliasPrimaryStatus(ctx, repository.SetReleaseAliasPrimaryStatusParams{
ReleaseID: id,
Alias: primary,
IsPrimary: false,
})
if err != nil {
return err
}
return tx.Commit(ctx)
}

View file

@ -0,0 +1,319 @@
package psql_test
import (
"context"
"testing"
"github.com/gabehf/koito/internal/catalog"
"github.com/gabehf/koito/internal/db"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func truncateTestData(t *testing.T) {
err := store.Exec(context.Background(),
`TRUNCATE
artists,
artist_aliases,
tracks,
artist_tracks,
releases,
artist_releases,
release_aliases,
listens
RESTART IDENTITY CASCADE`)
require.NoError(t, err)
}
func testDataForRelease(t *testing.T) {
truncateTestData(t)
err := store.Exec(context.Background(),
`INSERT INTO artists (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000001')`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
VALUES (1, 'ATARASHII GAKKO!', 'MusicBrainz', true)`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO artists (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000002')`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
VALUES (2, 'Masayuki Suzuki', 'MusicBrainz', true)`)
require.NoError(t, err)
}
func TestGetAlbum(t *testing.T) {
testDataForRelease(t)
ctx := context.Background()
// Insert test data
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
Title: "Test Release Group",
ArtistIDs: []int32{1},
})
require.NoError(t, err)
// Test GetAlbum by ID
result, err := store.GetAlbum(ctx, db.GetAlbumOpts{ID: rg.ID})
require.NoError(t, err)
assert.Equal(t, rg.ID, result.ID)
assert.Equal(t, "Test Release Group", result.Title)
// Test GetAlbum with insufficient information
_, err = store.GetAlbum(ctx, db.GetAlbumOpts{})
assert.Error(t, err)
truncateTestData(t)
}
func TestSaveAlbum(t *testing.T) {
testDataForRelease(t)
ctx := context.Background()
// Save release group with artist IDs
artistIDs := []int32{1, 2}
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
Title: "New Release Group",
ArtistIDs: artistIDs,
})
require.NoError(t, err)
// Verify release group was saved
assert.Equal(t, "New Release Group", rg.Title)
// Verify release was created for release group
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM releases_with_title
WHERE title = $1 AND id = $2
)`, "New Release Group", rg.ID)
require.NoError(t, err)
assert.True(t, exists, "expected release to exist")
// Verify artist associations were created for release group
for _, aid := range artistIDs {
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM artist_releases
WHERE artist_id = $1 AND release_id = $2
)`, aid, rg.ID)
require.NoError(t, err)
assert.True(t, exists, "expected artist association to exist")
}
truncateTestData(t)
}
func TestUpdateAlbum(t *testing.T) {
testDataForRelease(t)
ctx := context.Background()
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
Title: "Old Title",
ArtistIDs: []int32{1},
})
require.NoError(t, err)
newMbzID := uuid.New()
imgid := uuid.New()
err = store.UpdateAlbum(ctx, db.UpdateAlbumOpts{
ID: rg.ID,
MusicBrainzID: newMbzID,
Image: imgid,
ImageSrc: catalog.ImageSourceUserUpload,
})
require.NoError(t, err)
result, err := store.GetAlbum(ctx, db.GetAlbumOpts{ID: rg.ID})
require.NoError(t, err)
assert.Equal(t, newMbzID, *result.MbzID)
assert.Equal(t, imgid, *result.Image)
truncateTestData(t)
}
func TestAddArtistsToAlbum(t *testing.T) {
testDataForRelease(t)
ctx := context.Background()
// Insert test album
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
Title: "Test Album",
ArtistIDs: []int32{1},
})
require.NoError(t, err)
// Add additional artists to the album
err = store.AddArtistsToAlbum(ctx, db.AddArtistsToAlbumOpts{
AlbumID: rg.ID,
ArtistIDs: []int32{2},
})
require.NoError(t, err)
// Verify artist associations were created
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM artist_releases
WHERE artist_id = $1 AND release_id = $2
)`, 2, rg.ID)
require.NoError(t, err)
assert.True(t, exists, "expected artist association to exist")
truncateTestData(t)
}
func TestSaveAlbumAliases(t *testing.T) {
testDataForRelease(t)
ctx := context.Background()
// Insert test album
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
Title: "Test Album",
ArtistIDs: []int32{1},
})
require.NoError(t, err)
// Save aliases for the album
aliases := []string{"Alias 1", "Alias 2"}
err = store.SaveAlbumAliases(ctx, rg.ID, aliases, "TestSource")
require.NoError(t, err)
// Verify aliases were saved
for _, alias := range aliases {
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM release_aliases
WHERE release_id = $1 AND alias = $2
)`, rg.ID, alias)
require.NoError(t, err)
assert.True(t, exists, "expected alias to exist")
}
err = store.SetPrimaryAlbumAlias(ctx, 1, "Alias 1")
require.NoError(t, err)
album, err := store.GetAlbum(ctx, db.GetAlbumOpts{ID: rg.ID})
require.NoError(t, err)
assert.Equal(t, "Alias 1", album.Title)
err = store.SetPrimaryAlbumAlias(ctx, 1, "Fake Alias")
require.Error(t, err)
store.SetPrimaryAlbumAlias(ctx, 1, "Album One")
truncateTestData(t)
}
func TestDeleteAlbum(t *testing.T) {
testDataForRelease(t)
ctx := context.Background()
testDataForTopItems(t)
// Delete the album
err := store.DeleteAlbum(ctx, 1)
require.NoError(t, err)
// Verify album was deleted
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM releases
WHERE id = $1
)`, 1)
require.NoError(t, err)
assert.False(t, exists, "expected album to be deleted")
// Verify album's track was deleted
exists, err = store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM tracks
WHERE id = $1
)`, 1)
require.NoError(t, err)
assert.False(t, exists, "expected album's tracks to be deleted")
// Verify album's listens was deleted
exists, err = store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM listens
WHERE track_id = $1
)`, 1)
require.NoError(t, err)
assert.False(t, exists, "expected album's listens to be deleted")
truncateTestData(t)
}
func TestDeleteAlbumAlias(t *testing.T) {
testDataForRelease(t)
ctx := context.Background()
// Insert test album
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
Title: "Test Album",
ArtistIDs: []int32{1},
})
require.NoError(t, err)
// Save aliases for the album
aliases := []string{"Alias 1", "Alias 2"}
err = store.SaveAlbumAliases(ctx, rg.ID, aliases, "TestSource")
require.NoError(t, err)
// Delete one alias
err = store.DeleteAlbumAlias(ctx, rg.ID, "Alias 1")
require.NoError(t, err)
// Verify alias was deleted
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM release_aliases
WHERE release_id = $1 AND alias = $2
)`, rg.ID, "Alias 1")
require.NoError(t, err)
assert.False(t, exists, "expected alias to be deleted")
// Verify other alias still exists
exists, err = store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM release_aliases
WHERE release_id = $1 AND alias = $2
)`, rg.ID, "Alias 2")
require.NoError(t, err)
assert.True(t, exists, "expected alias to still exist")
truncateTestData(t)
}
func TestGetAllAlbumAliases(t *testing.T) {
testDataForRelease(t)
ctx := context.Background()
// Insert test album
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
Title: "Test Album",
ArtistIDs: []int32{1},
})
require.NoError(t, err)
// Save aliases for the album
aliases := []string{"Alias 1", "Alias 2"}
err = store.SaveAlbumAliases(ctx, rg.ID, aliases, "TestSource")
require.NoError(t, err)
// Retrieve all aliases
result, err := store.GetAllAlbumAliases(ctx, rg.ID)
require.NoError(t, err)
assert.Len(t, result, len(aliases)+1) // new + canonical
for _, alias := range aliases {
found := false
for _, res := range result {
if res.Alias == alias {
found = true
break
}
}
assert.True(t, found, "expected alias to be retrieved")
}
truncateTestData(t)
}

309
internal/db/psql/artist.go Normal file
View file

@ -0,0 +1,309 @@
package psql
import (
"context"
"errors"
"strings"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/gabehf/koito/internal/utils"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
)
func (d *Psql) GetArtist(ctx context.Context, opts db.GetArtistOpts) (*models.Artist, error) {
l := logger.FromContext(ctx)
if opts.ID != 0 {
l.Debug().Msgf("Fetching artist from DB with id %d", opts.ID)
row, err := d.q.GetArtist(ctx, opts.ID)
if err != nil {
return nil, err
}
count, err := d.q.CountListensFromArtist(ctx, repository.CountListensFromArtistParams{
ListenedAt: time.Unix(0, 0),
ListenedAt_2: time.Now(),
ArtistID: row.ID,
})
if err != nil {
return nil, err
}
return &models.Artist{
ID: row.ID,
MbzID: row.MusicBrainzID,
Name: row.Name,
Aliases: row.Aliases,
Image: row.Image,
ListenCount: count,
}, nil
} else if opts.MusicBrainzID != uuid.Nil {
l.Debug().Msgf("Fetching artist from DB with MusicBrainz ID %s", opts.MusicBrainzID)
row, err := d.q.GetArtistByMbzID(ctx, &opts.MusicBrainzID)
if err != nil {
return nil, err
}
count, err := d.q.CountListensFromArtist(ctx, repository.CountListensFromArtistParams{
ListenedAt: time.Unix(0, 0),
ListenedAt_2: time.Now(),
ArtistID: row.ID,
})
if err != nil {
return nil, err
}
return &models.Artist{
ID: row.ID,
MbzID: row.MusicBrainzID,
Name: row.Name,
Aliases: row.Aliases,
Image: row.Image,
ListenCount: count,
}, nil
} else if opts.Name != "" {
l.Debug().Msgf("Fetching artist from DB with name '%s'", opts.Name)
row, err := d.q.GetArtistByName(ctx, opts.Name)
if err != nil {
return nil, err
}
count, err := d.q.CountListensFromArtist(ctx, repository.CountListensFromArtistParams{
ListenedAt: time.Unix(0, 0),
ListenedAt_2: time.Now(),
ArtistID: row.ID,
})
if err != nil {
return nil, err
}
return &models.Artist{
ID: row.ID,
MbzID: row.MusicBrainzID,
Name: row.Name,
Aliases: row.Aliases,
Image: row.Image,
ListenCount: count,
}, nil
} else {
return nil, errors.New("insufficient information to get artist")
}
}
// Inserts all unique aliases into the DB with specified source
func (d *Psql) SaveArtistAliases(ctx context.Context, id int32, aliases []string, source string) error {
l := logger.FromContext(ctx)
if id == 0 {
return errors.New("artist id not specified")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
existing, err := qtx.GetAllArtistAliases(ctx, id)
if err != nil {
return err
}
for _, v := range existing {
aliases = append(aliases, v.Alias)
}
utils.Unique(&aliases)
for _, alias := range aliases {
if strings.TrimSpace(alias) == "" {
return errors.New("aliases cannot be blank")
}
err = qtx.InsertArtistAlias(ctx, repository.InsertArtistAliasParams{
Alias: strings.TrimSpace(alias),
ArtistID: id,
Source: source,
IsPrimary: false,
})
if err != nil {
return err
}
}
return tx.Commit(ctx)
}
func (d *Psql) DeleteArtist(ctx context.Context, id int32) error {
return d.q.DeleteArtist(ctx, id)
}
// Equivalent to Psql.SaveArtist, then Psql.SaveMbzAliases
func (d *Psql) SaveArtist(ctx context.Context, opts db.SaveArtistOpts) (*models.Artist, error) {
l := logger.FromContext(ctx)
var insertMbzID *uuid.UUID
var insertImage *uuid.UUID
if opts.MusicBrainzID != uuid.Nil {
insertMbzID = &opts.MusicBrainzID
}
if opts.Image != uuid.Nil {
insertImage = &opts.Image
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return nil, err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
opts.Name = strings.TrimSpace(opts.Name)
if opts.Name == "" {
return nil, errors.New("name must not be blank")
}
l.Debug().Msgf("Inserting artist '%s' into DB", opts.Name)
a, err := qtx.InsertArtist(ctx, repository.InsertArtistParams{
MusicBrainzID: insertMbzID,
Image: insertImage,
ImageSource: pgtype.Text{String: opts.ImageSrc, Valid: opts.ImageSrc != ""},
})
if err != nil {
return nil, err
}
l.Debug().Msgf("Inserting canonical alias '%s' into DB for artist with id %d", opts.Name, a.ID)
err = qtx.InsertArtistAlias(ctx, repository.InsertArtistAliasParams{
ArtistID: a.ID,
Alias: opts.Name,
Source: "Canonical",
IsPrimary: true,
})
if err != nil {
l.Error().Err(err).Msgf("Error inserting canonical alias for artist '%s'", opts.Name)
return nil, err
}
err = tx.Commit(ctx)
if err != nil {
l.Err(err).Msg("Failed to commit insert artist transaction")
return nil, err
}
artist := &models.Artist{
ID: a.ID,
Name: opts.Name,
Image: a.Image,
MbzID: a.MusicBrainzID,
Aliases: []string{opts.Name},
}
if len(opts.Aliases) > 0 {
l.Debug().Msgf("Inserting aliases '%v' into DB for artist '%s'", opts.Aliases, opts.Name)
err = d.SaveArtistAliases(ctx, a.ID, opts.Aliases, "MusicBrainz")
if err != nil {
return nil, err
}
artist.Aliases = opts.Aliases
}
return artist, nil
}
func (d *Psql) UpdateArtist(ctx context.Context, opts db.UpdateArtistOpts) error {
l := logger.FromContext(ctx)
if opts.ID == 0 {
return errors.New("artist id not specified")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
if opts.MusicBrainzID != uuid.Nil {
l.Debug().Msgf("Updating artist with id %d with MusicBrainz ID %s", opts.ID, opts.MusicBrainzID)
err := qtx.UpdateArtistMbzID(ctx, repository.UpdateArtistMbzIDParams{
ID: opts.ID,
MusicBrainzID: &opts.MusicBrainzID,
})
if err != nil {
return err
}
}
if opts.Image != uuid.Nil {
l.Debug().Msgf("Updating artist with id %d with image %s", opts.ID, opts.Image)
err = qtx.UpdateArtistImage(ctx, repository.UpdateArtistImageParams{
ID: opts.ID,
Image: &opts.Image,
ImageSource: pgtype.Text{String: opts.ImageSrc, Valid: opts.ImageSrc != ""},
})
if err != nil {
return err
}
}
return tx.Commit(ctx)
}
func (d *Psql) DeleteArtistAlias(ctx context.Context, id int32, alias string) error {
return d.q.DeleteArtistAlias(ctx, repository.DeleteArtistAliasParams{
ArtistID: id,
Alias: alias,
})
}
func (d *Psql) GetAllArtistAliases(ctx context.Context, id int32) ([]models.Alias, error) {
rows, err := d.q.GetAllArtistAliases(ctx, id)
if err != nil {
return nil, err
}
aliases := make([]models.Alias, len(rows))
for i, row := range rows {
aliases[i] = models.Alias{
ID: id,
Alias: row.Alias,
Source: row.Source,
Primary: row.IsPrimary,
}
}
return aliases, nil
}
func (d *Psql) SetPrimaryArtistAlias(ctx context.Context, id int32, alias string) error {
l := logger.FromContext(ctx)
if id == 0 {
return errors.New("artist id not specified")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
// get all aliases
aliases, err := qtx.GetAllArtistAliases(ctx, id)
if err != nil {
return err
}
primary := ""
exists := false
for _, v := range aliases {
if v.Alias == alias {
exists = true
}
if v.IsPrimary {
primary = v.Alias
}
}
if primary == alias {
// no-op rename
return nil
}
if !exists {
return errors.New("alias does not exist")
}
err = qtx.SetArtistAliasPrimaryStatus(ctx, repository.SetArtistAliasPrimaryStatusParams{
ArtistID: id,
Alias: alias,
IsPrimary: true,
})
if err != nil {
return err
}
err = qtx.SetArtistAliasPrimaryStatus(ctx, repository.SetArtistAliasPrimaryStatusParams{
ArtistID: id,
Alias: primary,
IsPrimary: false,
})
if err != nil {
return err
}
return tx.Commit(ctx)
}

View file

@ -0,0 +1,247 @@
package psql_test
import (
"context"
"slices"
"testing"
"github.com/gabehf/koito/internal/catalog"
"github.com/gabehf/koito/internal/db"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetArtist(t *testing.T) {
ctx := context.Background()
mbzId := uuid.MustParse("00000000-0000-0000-0000-000000000001")
// Insert test data
artist, err := store.SaveArtist(ctx, db.SaveArtistOpts{
Name: "Test Artist",
MusicBrainzID: mbzId,
})
require.NoError(t, err)
// Test GetArtist by ID
result, err := store.GetArtist(ctx, db.GetArtistOpts{ID: artist.ID})
require.NoError(t, err)
assert.Equal(t, artist.ID, result.ID)
assert.Equal(t, "Test Artist", result.Name)
// Test GetArtist by Name
result, err = store.GetArtist(ctx, db.GetArtistOpts{Name: artist.Name})
require.NoError(t, err)
assert.Equal(t, artist.ID, result.ID)
// Test GetArtist by MusicBrainzID
result, err = store.GetArtist(ctx, db.GetArtistOpts{MusicBrainzID: mbzId})
require.NoError(t, err)
assert.Equal(t, artist.ID, result.ID)
// Test GetArtist with insufficient information
_, err = store.GetArtist(ctx, db.GetArtistOpts{})
assert.Error(t, err)
truncateTestData(t)
}
func TestSaveAliases(t *testing.T) {
ctx := context.Background()
// Insert test artist
artist, err := store.SaveArtist(ctx, db.SaveArtistOpts{
Name: "Alias Artist",
})
require.NoError(t, err)
// Save aliases
aliases := []string{"Alias1", "Alias2"}
err = store.SaveArtistAliases(ctx, artist.ID, aliases, "MusicBrainz")
require.NoError(t, err)
// Verify aliases were saved
for _, alias := range aliases {
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM artist_aliases
WHERE artist_id = $1 AND alias = $2
)`, artist.ID, alias)
require.NoError(t, err)
assert.True(t, exists, "expected alias to exist")
}
err = store.SetPrimaryArtistAlias(ctx, 1, "Alias1")
require.NoError(t, err)
artist, err = store.GetArtist(ctx, db.GetArtistOpts{ID: artist.ID})
require.NoError(t, err)
assert.Equal(t, "Alias1", artist.Name)
err = store.SetPrimaryArtistAlias(ctx, 1, "Fake Alias")
require.Error(t, err)
truncateTestData(t)
}
func TestSaveArtist(t *testing.T) {
ctx := context.Background()
// Save artist with aliases
aliases := []string{"Alias1", "Alias2"}
artist, err := store.SaveArtist(ctx, db.SaveArtistOpts{
Name: "New Artist",
Aliases: aliases,
})
require.NoError(t, err)
// Verify artist was saved
assert.Equal(t, "New Artist", artist.Name)
// Verify aliases were saved
for _, alias := range slices.Concat(aliases, []string{"New Artist"}) {
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM artist_aliases
WHERE artist_id = $1 AND alias = $2
)`, artist.ID, alias)
require.NoError(t, err)
assert.True(t, exists, "expected alias '%s' to exist", alias)
}
truncateTestData(t)
}
func TestUpdateArtist(t *testing.T) {
ctx := context.Background()
// Insert test artist
artist, err := store.SaveArtist(ctx, db.SaveArtistOpts{
Name: "Old Name",
})
require.NoError(t, err)
imgid := uuid.New()
err = store.UpdateArtist(ctx, db.UpdateArtistOpts{
ID: artist.ID,
Image: imgid,
ImageSrc: catalog.ImageSourceUserUpload,
})
require.NoError(t, err)
result, err := store.GetArtist(ctx, db.GetArtistOpts{ID: artist.ID})
require.NoError(t, err)
assert.Equal(t, imgid, *result.Image)
truncateTestData(t)
}
func TestGetAllArtistAliases(t *testing.T) {
ctx := context.Background()
// Insert test artist
artist, err := store.SaveArtist(ctx, db.SaveArtistOpts{
Name: "Alias Artist",
Aliases: []string{"Alias1", "Alias2"},
})
require.NoError(t, err)
// Retrieve all aliases
result, err := store.GetAllArtistAliases(ctx, artist.ID)
require.NoError(t, err)
assert.Len(t, result, 3) // Includes canonical alias
// Verify aliases were retrieved
expectedAliases := []string{"Alias Artist", "Alias1", "Alias2"}
for _, alias := range expectedAliases {
found := false
for _, res := range result {
if res.Alias == alias {
found = true
break
}
}
assert.True(t, found, "expected alias '%s' to be retrieved", alias)
}
truncateTestData(t)
}
func TestDeleteArtistAlias(t *testing.T) {
ctx := context.Background()
// Insert test artist
artist, err := store.SaveArtist(ctx, db.SaveArtistOpts{
Name: "Alias Artist",
Aliases: []string{"Alias1", "Alias2"},
})
require.NoError(t, err)
// Delete one alias
err = store.DeleteArtistAlias(ctx, artist.ID, "Alias1")
require.NoError(t, err)
// Verify alias was deleted
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM artist_aliases
WHERE artist_id = $1 AND alias = $2
)`, artist.ID, "Alias1")
require.NoError(t, err)
assert.False(t, exists, "expected alias to be deleted")
// Verify other alias still exists
exists, err = store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM artist_aliases
WHERE artist_id = $1 AND alias = $2
)`, artist.ID, "Alias2")
require.NoError(t, err)
assert.True(t, exists, "expected alias to still exist")
truncateTestData(t)
}
func TestDeleteArtist(t *testing.T) {
ctx := context.Background()
// set up a lot of test data, 4 artists, 4 albums, 4 tracks, 10 listens
testDataForTopItems(t)
// Delete the artist
err := store.DeleteArtist(ctx, 1)
require.NoError(t, err)
// Verify artist was deleted
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM artists
WHERE id = $1
)`, 1)
require.NoError(t, err)
assert.False(t, exists, "expected artist to be deleted")
// Verify artist's release was deleted
exists, err = store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM releases
WHERE id = $1
)`, 1)
require.NoError(t, err)
assert.False(t, exists, "expected artist's release to be deleted")
// Verify artist's track was deleted
exists, err = store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM tracks
WHERE id = $1
)`, 1)
require.NoError(t, err)
assert.False(t, exists, "expected artist's tracks to be deleted")
// Verify artist's listens was deleted
exists, err = store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM listens
WHERE track_id = $1
)`, 1)
require.NoError(t, err)
assert.False(t, exists, "expected artist's listens to be deleted")
truncateTestData(t)
}

View file

@ -0,0 +1,70 @@
package psql
import (
"context"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/repository"
)
func (p *Psql) CountListens(ctx context.Context, period db.Period) (int64, error) {
t2 := time.Now()
t1 := db.StartTimeFromPeriod(period)
count, err := p.q.CountListens(ctx, repository.CountListensParams{
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return 0, err
}
return count, nil
}
func (p *Psql) CountTracks(ctx context.Context, period db.Period) (int64, error) {
t2 := time.Now()
t1 := db.StartTimeFromPeriod(period)
count, err := p.q.CountTopTracks(ctx, repository.CountTopTracksParams{
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return 0, err
}
return count, nil
}
func (p *Psql) CountAlbums(ctx context.Context, period db.Period) (int64, error) {
t2 := time.Now()
t1 := db.StartTimeFromPeriod(period)
count, err := p.q.CountTopReleases(ctx, repository.CountTopReleasesParams{
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return 0, err
}
return count, nil
}
func (p *Psql) CountArtists(ctx context.Context, period db.Period) (int64, error) {
t2 := time.Now()
t1 := db.StartTimeFromPeriod(period)
count, err := p.q.CountTopArtists(ctx, repository.CountTopArtistsParams{
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return 0, err
}
return count, nil
}
func (p *Psql) CountTimeListened(ctx context.Context, period db.Period) (int64, error) {
t2 := time.Now()
t1 := db.StartTimeFromPeriod(period)
count, err := p.q.CountTimeListened(ctx, repository.CountTimeListenedParams{
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return 0, err
}
return count, nil
}

View file

@ -0,0 +1,76 @@
package psql_test
import (
"context"
"testing"
"github.com/gabehf/koito/internal/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCountListens(t *testing.T) {
ctx := context.Background()
testDataForTopItems(t)
// Test CountListens
period := db.PeriodWeek
count, err := store.CountListens(ctx, period)
require.NoError(t, err)
assert.Equal(t, int64(1), count, "expected listens count to match inserted data")
truncateTestData(t)
}
func TestCountTracks(t *testing.T) {
ctx := context.Background()
testDataForTopItems(t)
// Test CountTracks
period := db.PeriodMonth
count, err := store.CountTracks(ctx, period)
require.NoError(t, err)
assert.Equal(t, int64(2), count, "expected tracks count to match inserted data")
truncateTestData(t)
}
func TestCountAlbums(t *testing.T) {
ctx := context.Background()
testDataForTopItems(t)
// Test CountAlbums
period := db.PeriodYear
count, err := store.CountAlbums(ctx, period)
require.NoError(t, err)
assert.Equal(t, int64(3), count, "expected albums count to match inserted data")
truncateTestData(t)
}
func TestCountArtists(t *testing.T) {
ctx := context.Background()
testDataForTopItems(t)
// Test CountArtists
period := db.PeriodAllTime
count, err := store.CountArtists(ctx, period)
require.NoError(t, err)
assert.Equal(t, int64(4), count, "expected artists count to match inserted data")
truncateTestData(t)
}
func TestCountTimeListened(t *testing.T) {
ctx := context.Background()
testDataForTopItems(t)
// Test CountTimeListened
period := db.PeriodMonth
count, err := store.CountTimeListened(ctx, period)
require.NoError(t, err)
// 3 listens in past month, each 100 seconds
assert.Equal(t, int64(300), count, "expected total time listened to match inserted data")
truncateTestData(t)
}

View file

@ -0,0 +1,74 @@
package psql
import (
"context"
"encoding/json"
"errors"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
)
func (d *Psql) ImageHasAssociation(ctx context.Context, image uuid.UUID) (bool, error) {
_, err := d.q.GetReleaseByImageID(ctx, &image)
if err == nil {
return true, err
} else if !errors.Is(err, pgx.ErrNoRows) {
return false, err
}
_, err = d.q.GetArtistByImage(ctx, &image)
if err == nil {
return true, err
} else if !errors.Is(err, pgx.ErrNoRows) {
return false, err
}
return false, nil
}
func (d *Psql) GetImageSource(ctx context.Context, image uuid.UUID) (string, error) {
r, err := d.q.GetReleaseByImageID(ctx, &image)
if err == nil {
return r.ImageSource.String, err
} else if !errors.Is(err, pgx.ErrNoRows) {
return "", err
}
rr, err := d.q.GetArtistByImage(ctx, &image)
if err == nil {
return rr.ImageSource.String, err
} else if !errors.Is(err, pgx.ErrNoRows) {
return "", err
}
return "", nil
}
func (d *Psql) AlbumsWithoutImages(ctx context.Context, from int32) ([]*models.Album, error) {
l := logger.FromContext(ctx)
rows, err := d.q.GetReleasesWithoutImages(ctx, repository.GetReleasesWithoutImagesParams{
Limit: 20,
ID: from,
})
if err != nil {
return nil, err
}
albums := make([]*models.Album, len(rows))
for i, row := range rows {
artists := make([]models.SimpleArtist, 0)
err = json.Unmarshal(row.Artists, &artists)
if err != nil {
l.Err(err).Msgf("Error unmarshalling artists for release group with id %d", row.ID)
artists = nil
}
albums[i] = &models.Album{
ID: row.ID,
Image: row.Image,
Title: row.Title,
MbzID: row.MusicBrainzID,
VariousArtists: row.VariousArtists,
Artists: artists,
}
}
return albums, nil
}

View file

@ -0,0 +1,106 @@
package psql_test
import (
"context"
"testing"
"github.com/gabehf/koito/internal/catalog"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setupTestDataForImages(t *testing.T) {
truncateTestData(t)
// Insert artists
err := store.Exec(context.Background(),
`INSERT INTO artists (musicbrainz_id, image, image_source)
VALUES ('00000000-0000-0000-0000-000000000001', '11111111-1111-1111-1111-111111111111', 'User Upload'),
('00000000-0000-0000-0000-000000000002', NULL, NULL)`)
require.NoError(t, err)
// Insert artist aliases
err = store.Exec(context.Background(),
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
VALUES (1, 'Artist One', 'Testing', true),
(2, 'Artist Two', 'Testing', true)`)
require.NoError(t, err)
// Insert albums
err = store.Exec(context.Background(),
`INSERT INTO releases (musicbrainz_id, image, image_source)
VALUES ('22222222-2222-2222-2222-222222222222', '33333333-3333-3333-3333-333333333333', 'Automatic'),
('44444444-4444-4444-4444-444444444444', NULL, NULL)`)
require.NoError(t, err)
// Insert release aliases
err = store.Exec(context.Background(),
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
VALUES (1, 'Album One', 'Testing', true),
(2, 'Album Two', 'Testing', true)`)
require.NoError(t, err)
// Associate albums with artists
err = store.Exec(context.Background(),
`INSERT INTO artist_releases (artist_id, release_id)
VALUES (1, 1), (2, 2)`)
require.NoError(t, err)
}
func TestImageHasAssociation(t *testing.T) {
ctx := context.Background()
setupTestDataForImages(t)
// Test image with association
imageID := uuid.MustParse("11111111-1111-1111-1111-111111111111")
hasAssociation, err := store.ImageHasAssociation(ctx, imageID)
require.NoError(t, err)
assert.True(t, hasAssociation, "expected image to have an association")
// Test image without association
imageID = uuid.MustParse("55555555-5555-5555-5555-555555555555")
hasAssociation, err = store.ImageHasAssociation(ctx, imageID)
require.NoError(t, err)
assert.False(t, hasAssociation, "expected image to have no association")
truncateTestData(t)
}
func TestGetImageSource(t *testing.T) {
ctx := context.Background()
setupTestDataForImages(t)
// Test image source for an album
imageID := uuid.MustParse("33333333-3333-3333-3333-333333333333")
source, err := store.GetImageSource(ctx, imageID)
require.NoError(t, err)
assert.Equal(t, "Automatic", source, "expected image source to match")
// Test image source for an artist
imageID = uuid.MustParse("11111111-1111-1111-1111-111111111111")
source, err = store.GetImageSource(ctx, imageID)
require.NoError(t, err)
assert.Equal(t, catalog.ImageSourceUserUpload, source, "expected image source to match")
// Test image source for a non-existent image
imageID = uuid.MustParse("55555555-5555-5555-5555-555555555555")
source, err = store.GetImageSource(ctx, imageID)
require.NoError(t, err)
assert.Equal(t, "", source, "expected no image source for non-existent image")
truncateTestData(t)
}
func TestAlbumsWithoutImages(t *testing.T) {
ctx := context.Background()
setupTestDataForImages(t)
// Test albums without images
albums, err := store.AlbumsWithoutImages(ctx, 0)
require.NoError(t, err)
require.Len(t, albums, 1, "expected one album without an image")
assert.Equal(t, "Album Two", albums[0].Title, "expected album title to match")
truncateTestData(t)
}

218
internal/db/psql/listen.go Normal file
View file

@ -0,0 +1,218 @@
package psql
import (
"context"
"encoding/json"
"errors"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/gabehf/koito/internal/utils"
)
func (d *Psql) GetListensPaginated(ctx context.Context, opts db.GetItemsOpts) (*db.PaginatedResponse[*models.Listen], error) {
l := logger.FromContext(ctx)
offset := (opts.Page - 1) * opts.Limit
t1, t2, err := utils.DateRange(opts.Week, opts.Month, opts.Year)
if err != nil {
return nil, err
}
if opts.Month == 0 && opts.Year == 0 {
// use period, not date range
t2 = time.Now()
t1 = db.StartTimeFromPeriod(opts.Period)
}
if opts.Limit == 0 {
opts.Limit = DefaultItemsPerPage
}
var listens []*models.Listen
var count int64
if opts.TrackID > 0 {
l.Debug().Msgf("Fetching %d listens with period %s on page %d from range %v to %v",
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetLastListensFromTrackPaginated(ctx, repository.GetLastListensFromTrackPaginatedParams{
ListenedAt: t1,
ListenedAt_2: t2,
Limit: int32(opts.Limit),
Offset: int32(offset),
ID: int32(opts.TrackID),
})
if err != nil {
return nil, err
}
listens = make([]*models.Listen, len(rows))
for i, row := range rows {
t := &models.Listen{
Track: models.Track{
Title: row.TrackTitle,
ID: row.TrackID,
},
Time: row.ListenedAt,
}
err = json.Unmarshal(row.Artists, &t.Track.Artists)
if err != nil {
return nil, err
}
listens[i] = t
}
count, err = d.q.CountListensFromTrack(ctx, repository.CountListensFromTrackParams{
ListenedAt: t1,
ListenedAt_2: t2,
TrackID: int32(opts.TrackID),
})
if err != nil {
return nil, err
}
} else if opts.AlbumID > 0 {
l.Debug().Msgf("Fetching %d listens with period %s on page %d from range %v to %v",
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetLastListensFromReleasePaginated(ctx, repository.GetLastListensFromReleasePaginatedParams{
ListenedAt: t1,
ListenedAt_2: t2,
Limit: int32(opts.Limit),
Offset: int32(offset),
ReleaseID: int32(opts.AlbumID),
})
if err != nil {
return nil, err
}
listens = make([]*models.Listen, len(rows))
for i, row := range rows {
t := &models.Listen{
Track: models.Track{
Title: row.TrackTitle,
ID: row.TrackID,
},
Time: row.ListenedAt,
}
err = json.Unmarshal(row.Artists, &t.Track.Artists)
if err != nil {
return nil, err
}
listens[i] = t
}
count, err = d.q.CountListensFromRelease(ctx, repository.CountListensFromReleaseParams{
ListenedAt: t1,
ListenedAt_2: t2,
ReleaseID: int32(opts.AlbumID),
})
if err != nil {
return nil, err
}
} else if opts.ArtistID > 0 {
l.Debug().Msgf("Fetching %d listens with period %s on page %d from range %v to %v",
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetLastListensFromArtistPaginated(ctx, repository.GetLastListensFromArtistPaginatedParams{
ListenedAt: t1,
ListenedAt_2: t2,
Limit: int32(opts.Limit),
Offset: int32(offset),
ArtistID: int32(opts.ArtistID),
})
if err != nil {
return nil, err
}
listens = make([]*models.Listen, len(rows))
for i, row := range rows {
t := &models.Listen{
Track: models.Track{
Title: row.TrackTitle,
ID: row.TrackID,
},
Time: row.ListenedAt,
}
err = json.Unmarshal(row.Artists, &t.Track.Artists)
if err != nil {
return nil, err
}
listens[i] = t
}
count, err = d.q.CountListensFromArtist(ctx, repository.CountListensFromArtistParams{
ListenedAt: t1,
ListenedAt_2: t2,
ArtistID: int32(opts.ArtistID),
})
if err != nil {
return nil, err
}
} else {
l.Debug().Msgf("Fetching %d listens with period %s on page %d from range %v to %v",
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetLastListensPaginated(ctx, repository.GetLastListensPaginatedParams{
ListenedAt: t1,
ListenedAt_2: t2,
Limit: int32(opts.Limit),
Offset: int32(offset),
})
if err != nil {
return nil, err
}
listens = make([]*models.Listen, len(rows))
for i, row := range rows {
t := &models.Listen{
Track: models.Track{
Title: row.TrackTitle,
ID: row.TrackID,
},
Time: row.ListenedAt,
}
err = json.Unmarshal(row.Artists, &t.Track.Artists)
if err != nil {
return nil, err
}
listens[i] = t
}
count, err = d.q.CountListens(ctx, repository.CountListensParams{
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return nil, err
}
l.Debug().Msgf("Database responded with %d tracks out of a total %d", len(rows), count)
}
return &db.PaginatedResponse[*models.Listen]{
Items: listens,
TotalCount: count,
ItemsPerPage: int32(opts.Limit),
HasNextPage: int64(offset+len(listens)) < count,
CurrentPage: int32(opts.Page),
}, nil
}
func (d *Psql) SaveListen(ctx context.Context, opts db.SaveListenOpts) error {
l := logger.FromContext(ctx)
if opts.TrackID == 0 {
return errors.New("required parameter TrackID missing")
}
if opts.Time.IsZero() {
opts.Time = time.Now()
}
var client *string
if opts.Client != "" {
client = &opts.Client
}
l.Debug().Msgf("Inserting listen for track with id %d at time %v into DB", opts.TrackID, opts.Time)
return d.q.InsertListen(ctx, repository.InsertListenParams{
TrackID: opts.TrackID,
ListenedAt: opts.Time,
UserID: opts.UserID,
Client: client,
})
}
func (d *Psql) DeleteListen(ctx context.Context, trackId int32, listenedAt time.Time) error {
l := logger.FromContext(ctx)
if trackId == 0 {
return errors.New("required parameter 'trackId' missing")
}
l.Debug().Msgf("Deleting listen from track %d at time %s from DB", trackId, listenedAt)
return d.q.DeleteListen(ctx, repository.DeleteListenParams{
TrackID: trackId,
ListenedAt: listenedAt,
})
}

View file

@ -0,0 +1,109 @@
package psql
import (
"context"
"errors"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/repository"
)
func (d *Psql) GetListenActivity(ctx context.Context, opts db.ListenActivityOpts) ([]db.ListenActivityItem, error) {
l := logger.FromContext(ctx)
if opts.Month != 0 && opts.Year == 0 {
return nil, errors.New("year must be specified with month")
}
// Default to range = 12 if not set
if opts.Range == 0 {
opts.Range = db.DefaultRange
}
t1, t2 := db.ListenActivityOptsToTimes(opts)
var listenActivity []db.ListenActivityItem
if opts.AlbumID > 0 {
l.Debug().Msgf("Fetching listen activity for %d %s(s) from %v to %v for release group %d",
opts.Range, opts.Step, t1.Format("Jan 02, 2006 15:04:05"), t2.Format("Jan 02, 2006 15:04:05"), opts.AlbumID)
rows, err := d.q.ListenActivityForRelease(ctx, repository.ListenActivityForReleaseParams{
Column1: t1,
Column2: t2,
Column3: stepToInterval(opts.Step),
ReleaseID: opts.AlbumID,
})
if err != nil {
return nil, err
}
listenActivity = make([]db.ListenActivityItem, len(rows))
for i, row := range rows {
t := db.ListenActivityItem{
Start: row.BucketStart,
Listens: row.ListenCount,
}
listenActivity[i] = t
}
l.Debug().Msgf("Database responded with %d steps", len(rows))
} else if opts.ArtistID > 0 {
l.Debug().Msgf("Fetching listen activity for %d %s(s) from %v to %v for artist %d",
opts.Range, opts.Step, t1.Format("Jan 02, 2006 15:04:05"), t2.Format("Jan 02, 2006 15:04:05"), opts.ArtistID)
rows, err := d.q.ListenActivityForArtist(ctx, repository.ListenActivityForArtistParams{
Column1: t1,
Column2: t2,
Column3: stepToInterval(opts.Step),
ArtistID: opts.ArtistID,
})
if err != nil {
return nil, err
}
listenActivity = make([]db.ListenActivityItem, len(rows))
for i, row := range rows {
t := db.ListenActivityItem{
Start: row.BucketStart,
Listens: row.ListenCount,
}
listenActivity[i] = t
}
l.Debug().Msgf("Database responded with %d steps", len(rows))
} else if opts.TrackID > 0 {
l.Debug().Msgf("Fetching listen activity for %d %s(s) from %v to %v for track %d",
opts.Range, opts.Step, t1.Format("Jan 02, 2006 15:04:05"), t2.Format("Jan 02, 2006 15:04:05"), opts.TrackID)
rows, err := d.q.ListenActivityForTrack(ctx, repository.ListenActivityForTrackParams{
Column1: t1,
Column2: t2,
Column3: stepToInterval(opts.Step),
ID: opts.TrackID,
})
if err != nil {
return nil, err
}
listenActivity = make([]db.ListenActivityItem, len(rows))
for i, row := range rows {
t := db.ListenActivityItem{
Start: row.BucketStart,
Listens: row.ListenCount,
}
listenActivity[i] = t
}
l.Debug().Msgf("Database responded with %d steps", len(rows))
} else {
l.Debug().Msgf("Fetching listen activity for %d %s(s) from %v to %v",
opts.Range, opts.Step, t1.Format("Jan 02, 2006 15:04:05"), t2.Format("Jan 02, 2006 15:04:05"))
rows, err := d.q.ListenActivity(ctx, repository.ListenActivityParams{
Column1: t1,
Column2: t2,
Column3: stepToInterval(opts.Step),
})
if err != nil {
return nil, err
}
listenActivity = make([]db.ListenActivityItem, len(rows))
for i, row := range rows {
t := db.ListenActivityItem{
Start: row.BucketStart,
Listens: row.ListenCount,
}
listenActivity[i] = t
}
l.Debug().Msgf("Database responded with %d steps", len(rows))
}
return listenActivity, nil
}

View file

@ -0,0 +1,211 @@
package psql_test
import (
"context"
"testing"
"github.com/gabehf/koito/internal/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func flattenListenCounts(items []db.ListenActivityItem) []int64 {
ret := make([]int64, len(items))
for i, v := range items {
ret[i] = v.Listens
}
return ret
}
func TestListenActivity(t *testing.T) {
truncateTestData(t)
err := store.Exec(context.Background(),
`INSERT INTO artists (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000001'),
('00000000-0000-0000-0000-000000000002')`)
require.NoError(t, err)
// Move artist names into artist_aliases
err = store.Exec(context.Background(),
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
VALUES (1, 'Artist One', 'Testing', true),
(2, 'Artist Two', 'Testing', true)`)
require.NoError(t, err)
// Insert release groups
err = store.Exec(context.Background(),
`INSERT INTO releases (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000011'),
('00000000-0000-0000-0000-000000000022')`)
require.NoError(t, err)
// Move release titles into release_aliases
err = store.Exec(context.Background(),
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
VALUES (1, 'Release One', 'Testing', true),
(2, 'Release Two', 'Testing', true)`)
require.NoError(t, err)
// Insert tracks
err = store.Exec(context.Background(),
`INSERT INTO tracks (musicbrainz_id, release_id)
VALUES ('11111111-1111-1111-1111-111111111111', 1),
('22222222-2222-2222-2222-222222222222', 2)`)
require.NoError(t, err)
// Move track titles into track_aliases
err = store.Exec(context.Background(),
`INSERT INTO track_aliases (track_id, alias, source, is_primary)
VALUES (1, 'Track One', 'Testing', true),
(2, 'Track Two', 'Testing', true)`)
require.NoError(t, err)
// Associate tracks with artists
err = store.Exec(context.Background(),
`INSERT INTO artist_tracks (artist_id, track_id)
VALUES (1, 1), (2, 2)`)
require.NoError(t, err)
// Insert listens
err = store.Exec(context.Background(),
`INSERT INTO listens (user_id, track_id, listened_at)
VALUES (1, 1, NOW() - INTERVAL '1 day'),
(1, 1, NOW() - INTERVAL '2 days'),
(1, 1, NOW() - INTERVAL '1 week 1 day'),
(1, 1, NOW() - INTERVAL '1 month 1 day'),
(1, 1, NOW() - INTERVAL '1 year 1 day'),
(1, 2, NOW() - INTERVAL '1 day'),
(1, 2, NOW() - INTERVAL '2 days'),
(1, 2, NOW() - INTERVAL '1 week 1 day'),
(1, 2, NOW() - INTERVAL '1 month 1 day'),
(1, 2, NOW() - INTERVAL '1 year 1 day')`)
require.NoError(t, err)
ctx := context.Background()
// Test for opts.Step = db.StepDay
activity, err := store.GetListenActivity(ctx, db.ListenActivityOpts{Step: db.StepDay})
require.NoError(t, err)
require.Len(t, activity, db.DefaultRange)
assert.Equal(t, []int64{0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 2, 0}, flattenListenCounts(activity))
// Truncate listens table and insert specific dates for testing opts.Step = db.StepMonth
err = store.Exec(context.Background(), `TRUNCATE TABLE listens`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO listens (user_id, track_id, listened_at)
VALUES (1, 1, NOW() - INTERVAL '1 month'),
(1, 1, NOW() - INTERVAL '2 months'),
(1, 1, NOW() - INTERVAL '3 months'),
(1, 2, NOW() - INTERVAL '1 month'),
(1, 2, NOW() - INTERVAL '2 months')`)
require.NoError(t, err)
activity, err = store.GetListenActivity(ctx, db.ListenActivityOpts{Step: db.StepMonth, Range: 8})
require.NoError(t, err)
require.Len(t, activity, 8)
assert.Equal(t, []int64{0, 0, 0, 0, 1, 2, 2, 0}, flattenListenCounts(activity))
// Truncate listens table and insert specific dates for testing opts.Step = db.StepYear
err = store.Exec(context.Background(), `TRUNCATE TABLE listens RESTART IDENTITY`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO listens (user_id, track_id, listened_at)
VALUES (1, 1, NOW() - INTERVAL '1 year'),
(1, 1, NOW() - INTERVAL '2 years'),
(1, 2, NOW() - INTERVAL '1 year'),
(1, 2, NOW() - INTERVAL '3 years')`)
require.NoError(t, err)
activity, err = store.GetListenActivity(ctx, db.ListenActivityOpts{Step: db.StepYear})
require.NoError(t, err)
require.Len(t, activity, db.DefaultRange)
assert.Equal(t, []int64{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 0}, flattenListenCounts(activity))
// Truncate and insert data for a specific month/year
err = store.Exec(context.Background(), `TRUNCATE TABLE listens RESTART IDENTITY`)
require.NoError(t, err)
err = store.Exec(context.Background(), `
INSERT INTO listens (user_id, track_id, listened_at)
VALUES (1, 1, DATE '2024-03-10'),
(1, 2, DATE '2024-03-20')`)
require.NoError(t, err)
activity, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
Step: db.StepDay,
Month: 3,
Year: 2024,
})
require.NoError(t, err)
require.Len(t, activity, 31) // number of days in march
assert.EqualValues(t, 1, activity[8].Listens)
assert.EqualValues(t, 1, activity[18].Listens)
// Truncate and insert listens associated with two different albums
err = store.Exec(context.Background(), `TRUNCATE TABLE listens RESTART IDENTITY`)
require.NoError(t, err)
err = store.Exec(context.Background(), `
INSERT INTO listens (user_id, track_id, listened_at)
VALUES (1, 1, NOW() - INTERVAL '1 day'), (1, 1, NOW() - INTERVAL '2 days'),
(1, 2, NOW() - INTERVAL '1 day')`)
require.NoError(t, err)
activity, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
Step: db.StepDay,
AlbumID: 1, // Track 1 only
})
require.NoError(t, err)
require.Len(t, activity, db.DefaultRange)
assert.Equal(t, []int64{0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0}, flattenListenCounts(activity))
activity, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
Step: db.StepDay,
TrackID: 1, // Track 1 only
})
require.NoError(t, err)
require.Len(t, activity, db.DefaultRange)
assert.Equal(t, []int64{0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0}, flattenListenCounts(activity))
activity, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
Step: db.StepDay,
ArtistID: 2, // Should only include listens to Track 2
})
require.NoError(t, err)
require.Len(t, activity, db.DefaultRange)
assert.Equal(t, []int64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}, flattenListenCounts(activity))
// month without year is disallowed
_, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
Step: db.StepDay,
Month: 5,
})
require.Error(t, err)
// invalid options
_, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
Year: -10,
})
require.Error(t, err)
_, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
Year: 2025,
Month: -10,
})
require.Error(t, err)
_, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
Range: -1,
})
require.Error(t, err)
_, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
AlbumID: -1,
})
require.Error(t, err)
_, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
ArtistID: -1,
})
require.Error(t, err)
}

View file

@ -0,0 +1,219 @@
package psql_test
import (
"context"
"testing"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func testDataForListens(t *testing.T) {
truncateTestData(t)
// Insert artists
err := store.Exec(context.Background(),
`INSERT INTO artists (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000001'),
('00000000-0000-0000-0000-000000000002')`)
require.NoError(t, err)
// Insert artist aliases
err = store.Exec(context.Background(),
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
VALUES (1, 'Artist One', 'Testing', true),
(2, 'Artist Two', 'Testing', true)`)
require.NoError(t, err)
// Insert release groups
err = store.Exec(context.Background(),
`INSERT INTO releases (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000011'),
('00000000-0000-0000-0000-000000000022')`)
require.NoError(t, err)
// Insert release aliases
err = store.Exec(context.Background(),
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
VALUES (1, 'Release One', 'Testing', true),
(2, 'Release Two', 'Testing', true)`)
require.NoError(t, err)
// Insert tracks
err = store.Exec(context.Background(),
`INSERT INTO tracks (musicbrainz_id, release_id)
VALUES ('11111111-1111-1111-1111-111111111111', 1),
('22222222-2222-2222-2222-222222222222', 2)`)
require.NoError(t, err)
// Insert track aliases
err = store.Exec(context.Background(),
`INSERT INTO track_aliases (track_id, alias, source, is_primary)
VALUES (1, 'Track One', 'Testing', true),
(2, 'Track Two', 'Testing', true)`)
require.NoError(t, err)
// Insert artist track associations
err = store.Exec(context.Background(),
`INSERT INTO artist_tracks (track_id, artist_id)
VALUES (1, 1),
(2, 2)`)
require.NoError(t, err)
}
func TestGetListens(t *testing.T) {
testDataForTopItems(t)
ctx := context.Background()
// Test valid
resp, err := store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime})
require.NoError(t, err)
require.Len(t, resp.Items, 10)
assert.Equal(t, int64(10), resp.TotalCount)
require.Len(t, resp.Items[0].Track.Artists, 1)
require.Len(t, resp.Items[1].Track.Artists, 1)
// ensure tracks are in the right order (time, desc)
assert.Equal(t, "Artist Four", resp.Items[0].Track.Artists[0].Name)
assert.Equal(t, "Artist Three", resp.Items[1].Track.Artists[0].Name)
// Test pagination
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 2, Period: db.PeriodAllTime})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
require.Len(t, resp.Items[0].Track.Artists, 1)
assert.Equal(t, true, resp.HasNextPage)
assert.EqualValues(t, 2, resp.CurrentPage)
assert.EqualValues(t, 1, resp.ItemsPerPage)
assert.EqualValues(t, 10, resp.TotalCount)
assert.Equal(t, "Artist Three", resp.Items[0].Track.Artists[0].Name)
// Test page out of range
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Limit: 10, Page: 10, Period: db.PeriodAllTime})
require.NoError(t, err)
assert.Empty(t, resp.Items)
assert.False(t, resp.HasNextPage)
// Test invalid inputs
_, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Limit: -1, Page: 0})
assert.Error(t, err)
_, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: -1})
assert.Error(t, err)
// Test specify period
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodDay})
require.NoError(t, err)
require.Len(t, resp.Items, 0) // empty
assert.Equal(t, int64(0), resp.TotalCount)
// should default to PeriodDay
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{})
require.NoError(t, err)
require.Len(t, resp.Items, 0) // empty
assert.Equal(t, int64(0), resp.TotalCount)
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodWeek})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodMonth})
require.NoError(t, err)
require.Len(t, resp.Items, 3)
assert.Equal(t, int64(3), resp.TotalCount)
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodYear})
require.NoError(t, err)
require.Len(t, resp.Items, 6)
assert.Equal(t, int64(6), resp.TotalCount)
// Test filter by artists, releases, and tracks
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, ArtistID: 1})
require.NoError(t, err)
require.Len(t, resp.Items, 4)
assert.Equal(t, int64(4), resp.TotalCount)
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, AlbumID: 2})
require.NoError(t, err)
require.Len(t, resp.Items, 3)
assert.Equal(t, int64(3), resp.TotalCount)
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, TrackID: 3})
require.NoError(t, err)
require.Len(t, resp.Items, 2)
assert.Equal(t, int64(2), resp.TotalCount)
// when both artistID and albumID are specified, artist id is ignored
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, AlbumID: 2, ArtistID: 1})
require.NoError(t, err)
require.Len(t, resp.Items, 3)
assert.Equal(t, int64(3), resp.TotalCount)
// Test specify dates
testDataAbsoluteListenTimes(t)
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Year: 2023})
require.NoError(t, err)
require.Len(t, resp.Items, 4)
assert.Equal(t, int64(4), resp.TotalCount)
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Month: 6, Year: 2024})
require.NoError(t, err)
require.Len(t, resp.Items, 3)
assert.Equal(t, int64(3), resp.TotalCount)
// invalid, year required with month
_, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Month: 10})
require.Error(t, err)
}
func TestSaveListen(t *testing.T) {
testDataForListens(t)
ctx := context.Background()
// Test SaveListen with valid inputs
err := store.SaveListen(ctx, db.SaveListenOpts{
TrackID: 1,
Time: time.Now(),
UserID: 1,
})
require.NoError(t, err)
// Verify the listen was saved
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM listens
WHERE track_id = $1
)`, 1)
require.NoError(t, err)
assert.True(t, exists, "expected listen to exist")
// Test SaveListen with missing TrackID
err = store.SaveListen(ctx, db.SaveListenOpts{
TrackID: 0,
Time: time.Now(),
})
assert.Error(t, err)
}
func TestDeleteListen(t *testing.T) {
testDataForListens(t)
ctx := context.Background()
err := store.Exec(ctx, `
INSERT INTO listens (user_id, track_id, listened_at)
VALUES (1, 1, to_timestamp(1749464138.0))`)
require.NoError(t, err)
err = store.DeleteListen(ctx, 1, time.Unix(1749464138, 0))
require.NoError(t, err)
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM listens
WHERE track_id = $1
)`, 1)
require.NoError(t, err)
assert.False(t, exists, "expected listen to be deleted")
}

109
internal/db/psql/merge.go Normal file
View file

@ -0,0 +1,109 @@
package psql
import (
"context"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/repository"
"github.com/jackc/pgx/v5"
)
func (d *Psql) MergeTracks(ctx context.Context, fromId, toId int32) error {
l := logger.FromContext(ctx)
l.Info().Msgf("Merging track %d into track %d", fromId, toId)
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
err = qtx.UpdateTrackIdForListens(ctx, repository.UpdateTrackIdForListensParams{
TrackID: fromId,
TrackID_2: toId,
})
if err != nil {
return err
}
err = qtx.CleanOrphanedEntries(ctx)
if err != nil {
l.Err(err).Msg("Failed to clean orphaned entries")
return err
}
return tx.Commit(ctx)
}
func (d *Psql) MergeAlbums(ctx context.Context, fromId, toId int32) error {
l := logger.FromContext(ctx)
l.Info().Msgf("Merging album %d into album %d", fromId, toId)
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
err = qtx.UpdateReleaseForAll(ctx, repository.UpdateReleaseForAllParams{
ReleaseID: fromId,
ReleaseID_2: toId,
})
if err != nil {
return err
}
err = qtx.CleanOrphanedEntries(ctx)
if err != nil {
l.Err(err).Msg("Failed to clean orphaned entries")
return err
}
return tx.Commit(ctx)
}
func (d *Psql) MergeArtists(ctx context.Context, fromId, toId int32) error {
l := logger.FromContext(ctx)
l.Info().Msgf("Merging artist %d into artist %d", fromId, toId)
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
err = qtx.DeleteConflictingArtistTracks(ctx, repository.DeleteConflictingArtistTracksParams{
ArtistID: fromId,
ArtistID_2: toId,
})
if err != nil {
l.Err(err).Msg("Failed to delete conflicting artist tracks")
return err
}
err = qtx.DeleteConflictingArtistReleases(ctx, repository.DeleteConflictingArtistReleasesParams{
ArtistID: fromId,
ArtistID_2: toId,
})
if err != nil {
l.Err(err).Msg("Failed to delete conflicting artist releases")
return err
}
err = qtx.UpdateArtistTracks(ctx, repository.UpdateArtistTracksParams{
ArtistID: fromId,
ArtistID_2: toId,
})
if err != nil {
l.Err(err).Msg("Failed to update artist tracks")
return err
}
err = qtx.UpdateArtistReleases(ctx, repository.UpdateArtistReleasesParams{
ArtistID: fromId,
ArtistID_2: toId,
})
if err != nil {
l.Err(err).Msg("Failed to update artist releases")
return err
}
err = qtx.CleanOrphanedEntries(ctx)
if err != nil {
l.Err(err).Msg("Failed to clean orphaned entries")
return err
}
return tx.Commit(ctx)
}

View file

@ -0,0 +1,124 @@
package psql_test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setupTestDataForMerge(t *testing.T) {
truncateTestData(t)
// Insert artists
err := store.Exec(context.Background(),
`INSERT INTO artists (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000001'),
('00000000-0000-0000-0000-000000000002')`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
VALUES (1, 'Artist One', 'Testing', true),
(2, 'Artist Two', 'Testing', true)`)
require.NoError(t, err)
// Insert albums
err = store.Exec(context.Background(),
`INSERT INTO releases (musicbrainz_id)
VALUES ('11111111-1111-1111-1111-111111111111'),
('22222222-2222-2222-2222-222222222222')`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
VALUES (1, 'Album One', 'Testing', true),
(2, 'Album Two', 'Testing', true)`)
require.NoError(t, err)
// Insert tracks
err = store.Exec(context.Background(),
`INSERT INTO tracks (musicbrainz_id, release_id)
VALUES ('33333333-3333-3333-3333-333333333333', 1),
('44444444-4444-4444-4444-444444444444', 2)`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO track_aliases (track_id, alias, source, is_primary)
VALUES (1, 'Track One', 'Testing', true),
(2, 'Track Two', 'Testing', true)`)
require.NoError(t, err)
// Associate artists with albums and tracks
err = store.Exec(context.Background(),
`INSERT INTO artist_releases (artist_id, release_id)
VALUES (1, 1), (2, 2)`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO artist_tracks (artist_id, track_id)
VALUES (1, 1), (2, 2)`)
require.NoError(t, err)
// Insert listens
err = store.Exec(context.Background(),
`INSERT INTO listens (user_id, track_id, listened_at)
VALUES (1, 1, NOW() - INTERVAL '1 day'),
(1, 2, NOW() - INTERVAL '2 days')`)
require.NoError(t, err)
}
func TestMergeTracks(t *testing.T) {
ctx := context.Background()
setupTestDataForMerge(t)
// Merge Track 1 into Track 2
err := store.MergeTracks(ctx, 1, 2)
require.NoError(t, err)
// Verify listens are updated
var count int
count, err = store.Count(ctx, `SELECT COUNT(*) FROM listens WHERE track_id = 2`)
require.NoError(t, err)
assert.Equal(t, 2, count, "expected all listens to be merged into Track 2")
truncateTestData(t)
}
func TestMergeAlbums(t *testing.T) {
ctx := context.Background()
setupTestDataForMerge(t)
// Merge Album 1 into Album 2
err := store.MergeAlbums(ctx, 1, 2)
require.NoError(t, err)
// Verify tracks are updated
var count int
count, err = store.Count(ctx, `SELECT COUNT(*) FROM tracks WHERE release_id = 2`)
require.NoError(t, err)
assert.Equal(t, 2, count, "expected all tracks to be merged into Album 2")
truncateTestData(t)
}
func TestMergeArtists(t *testing.T) {
ctx := context.Background()
setupTestDataForMerge(t)
// Merge Artist 1 into Artist 2
err := store.MergeArtists(ctx, 1, 2)
require.NoError(t, err)
// Verify artist associations are updated
var count int
count, err = store.Count(ctx, `SELECT COUNT(*) FROM artist_tracks WHERE artist_id = 2`)
require.NoError(t, err)
assert.Equal(t, 2, count, "expected all tracks to be associated with Artist 2")
count, err = store.Count(ctx, `SELECT COUNT(*) FROM artist_releases WHERE artist_id = 2`)
require.NoError(t, err)
assert.Equal(t, 2, count, "expected all releases to be associated with Artist 2")
truncateTestData(t)
}

119
internal/db/psql/psql.go Normal file
View file

@ -0,0 +1,119 @@
// package psql implements the db.DB interface using psx and a sql generated repository
package psql
import (
"context"
"database/sql"
"fmt"
"path/filepath"
"runtime"
"time"
"github.com/gabehf/koito/internal/cfg"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/repository"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/pressly/goose/v3"
)
const (
DefaultItemsPerPage = 20
)
type Psql struct {
q *repository.Queries
conn *pgxpool.Pool
}
func New() (*Psql, error) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
config, err := pgxpool.ParseConfig(cfg.DatabaseUrl())
if err != nil {
return nil, fmt.Errorf("failed to parse pgx config: %w", err)
}
config.ConnConfig.ConnectTimeout = 15 * time.Second
pool, err := pgxpool.NewWithConfig(ctx, config)
if err != nil {
return nil, fmt.Errorf("failed to create pgx pool: %w", err)
}
if err := pool.Ping(ctx); err != nil {
pool.Close()
return nil, fmt.Errorf("database not reachable: %w", err)
}
sqlDB, err := sql.Open("pgx", cfg.DatabaseUrl())
if err != nil {
return nil, fmt.Errorf("failed to open db for migrations: %w", err)
}
_, filename, _, ok := runtime.Caller(0)
if !ok {
return nil, fmt.Errorf("unable to get caller info")
}
migrationsPath := filepath.Join(filepath.Dir(filename), "..", "..", "..", "db", "migrations")
if err := goose.Up(sqlDB, migrationsPath); err != nil {
return nil, fmt.Errorf("goose failed: %w", err)
}
_ = sqlDB.Close()
return &Psql{
q: repository.New(pool),
conn: pool,
}, nil
}
// Not part of the DB interface this package implements. Only used for testing.
func (d *Psql) Exec(ctx context.Context, query string, args ...any) error {
_, err := d.conn.Exec(ctx, query, args...)
return err
}
// Not part of the DB interface this package implements. Only used for testing.
func (d *Psql) RowExists(ctx context.Context, query string, args ...any) (bool, error) {
var exists bool
err := d.conn.QueryRow(ctx, query, args...).Scan(&exists)
return exists, err
}
func (p *Psql) Count(ctx context.Context, query string, args ...any) (count int, err error) {
err = p.conn.QueryRow(ctx, query, args...).Scan(&count)
return
}
// Exposes p.conn.QueryRow. Only used for testing. Not part of the DB interface this package implements.
func (p *Psql) QueryRow(ctx context.Context, query string, args ...any) pgx.Row {
return p.conn.QueryRow(ctx, query, args...)
}
func (d *Psql) Close(ctx context.Context) {
d.conn.Close()
}
func (d *Psql) Ping(ctx context.Context) error {
return d.conn.Ping(ctx)
}
func stepToInterval(p db.StepInterval) pgtype.Interval {
var interval pgtype.Interval
switch p {
case db.StepDay:
interval.Days = 1
case db.StepWeek:
interval.Days = 7
case db.StepMonth:
interval.Months = 1
case db.StepYear:
interval.Months = 12
}
interval.Valid = true
return interval
}

View file

@ -0,0 +1,186 @@
package psql_test
import (
"context"
"fmt"
"log"
"testing"
"github.com/gabehf/koito/internal/cfg"
"github.com/gabehf/koito/internal/db/psql"
_ "github.com/gabehf/koito/testing_init"
"github.com/ory/dockertest/v3"
"github.com/stretchr/testify/require"
)
var store *psql.Psql
func getTestGetenv(resource *dockertest.Resource) func(string) string {
return func(env string) string {
switch env {
case cfg.DATABASE_URL_ENV:
return fmt.Sprintf("postgres://postgres:secret@localhost:%s", resource.GetPort("5432/tcp"))
default:
return ""
}
}
}
func TestMain(m *testing.M) {
// uses a sensible default on windows (tcp/http) and linux/osx (socket)
pool, err := dockertest.NewPool("")
if err != nil {
log.Fatalf("Could not construct pool: %s", err)
}
// uses pool to try to connect to Docker
err = pool.Client.Ping()
if err != nil {
log.Fatalf("Could not connect to Docker: %s", err)
}
// pulls an image, creates a container based on it and runs it
resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret"})
if err != nil {
log.Fatalf("Could not start resource: %s", err)
}
err = cfg.Load(getTestGetenv(resource))
if err != nil {
log.Fatalf("Could not load cfg: %s", err)
}
// exponential backoff-retry, because the application in the container might not be ready to accept connections yet
if err := pool.Retry(func() error {
var err error
store, err = psql.New()
if err != nil {
log.Println("Failed to connect to test database, retrying...")
return err
}
return store.Ping(context.Background())
}); err != nil {
log.Fatalf("Could not connect to database: %s", err)
}
// as of go1.15 testing.M returns the exit code of m.Run(), so it is safe to use defer here
defer func() {
if err := pool.Purge(resource); err != nil {
log.Fatalf("Could not purge resource: %s", err)
}
}()
// insert a user into the db with id 1 to use for tests
err = store.Exec(context.Background(), `INSERT INTO users (username, password) VALUES ('test', DECODE('abc123', 'hex'))`)
if err != nil {
log.Fatalf("Failed to insert test user: %v", err)
}
m.Run()
}
func testDataForTopItems(t *testing.T) {
truncateTestData(t)
// artist 1 has most listens older than 1 year
// artist 2 has most listens older than 1 month
// artist 3 has most listens older than 1 week
// artist 4 has least listens
err := store.Exec(context.Background(),
`INSERT INTO artists (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000001'),
('00000000-0000-0000-0000-000000000002'),
('00000000-0000-0000-0000-000000000003'),
('00000000-0000-0000-0000-000000000004')`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
VALUES (1, 'Artist One', 'Testing', true),
(2, 'Artist Two', 'Testing', true),
(3, 'Artist Three', 'Testing', true),
(4, 'Artist Four', 'Testing', true)`)
require.NoError(t, err)
// Insert release groups
err = store.Exec(context.Background(),
`INSERT INTO releases (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000011'),
('00000000-0000-0000-0000-000000000022'),
('00000000-0000-0000-0000-000000000033'),
('00000000-0000-0000-0000-000000000044')`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
VALUES (1, 'Release One', 'Testing', true),
(2, 'Release Two', 'Testing', true),
(3, 'Release Three', 'Testing', true),
(4, 'Release Four', 'Testing', true)`)
require.NoError(t, err)
// Insert release groups
err = store.Exec(context.Background(),
`INSERT INTO artist_releases (release_id, artist_id)
VALUES (1, 1), (2, 2), (3, 3), (4, 4)`)
require.NoError(t, err)
// Insert tracks
err = store.Exec(context.Background(),
`INSERT INTO tracks (musicbrainz_id, release_id, duration)
VALUES ('11111111-1111-1111-1111-111111111111', 1, 100),
('22222222-2222-2222-2222-222222222222', 2, 100),
('33333333-3333-3333-3333-333333333333', 3, 100),
('44444444-4444-4444-4444-444444444444', 4, 100)`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO track_aliases (track_id, alias, source, is_primary)
VALUES (1, 'Track One', 'Testing', true),
(2, 'Track Two', 'Testing', true),
(3, 'Track Three', 'Testing', true),
(4, 'Track Four', 'Testing', true)`)
require.NoError(t, err)
// Associate tracks with artists
err = store.Exec(context.Background(),
`INSERT INTO artist_tracks (artist_id, track_id)
VALUES (1, 1), (2, 2), (3, 3), (4, 4)`)
require.NoError(t, err)
// Insert listens
err = store.Exec(context.Background(),
`INSERT INTO listens (user_id, track_id, listened_at)
VALUES (1, 1, NOW() - INTERVAL '2 years 1 day'),
(1, 1, NOW() - INTERVAL '2 years 2 days'),
(1, 1, NOW() - INTERVAL '2 years 3 days'),
(1, 1, NOW() - INTERVAL '2 years 4 days'),
(1, 2, NOW() - INTERVAL '2 months 1 day'),
(1, 2, NOW() - INTERVAL '2 months 2 days'),
(1, 2, NOW() - INTERVAL '2 months 3 days'),
(1, 3, NOW() - INTERVAL '2 weeks'),
(1, 3, NOW() - INTERVAL '2 weeks 1 day'),
(1, 4, NOW() - INTERVAL '2 days')`)
require.NoError(t, err)
}
func testDataAbsoluteListenTimes(t *testing.T) {
err := store.Exec(context.Background(),
`TRUNCATE listens`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO listens (user_id, track_id, listened_at)
VALUES (1, 1, '2023-06-22 19:11:25-07'),
(1, 1, '2023-06-22 19:12:25-07'),
(1, 1, '2023-06-22 19:13:25-07'),
(1, 1, '2023-06-22 19:14:25-07'),
(1, 2, '2024-06-22 19:15:25-07'),
(1, 2, '2024-06-22 19:16:25-07'),
(1, 2, '2024-06-22 19:17:25-07'),
(1, 3, '2024-10-02 19:18:25-07'),
(1, 3, '2024-10-02 19:19:25-07'),
(1, 4, '2025-05-16 19:20:25-07')`)
require.NoError(t, err)
}

151
internal/db/psql/search.go Normal file
View file

@ -0,0 +1,151 @@
package psql
import (
"context"
"encoding/json"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/jackc/pgx/v5/pgtype"
)
const searchItemLimit = 5
const substringSearchLength = 6
func (d *Psql) SearchArtists(ctx context.Context, q string) ([]*models.Artist, error) {
if len(q) < substringSearchLength {
rows, err := d.q.SearchArtistsBySubstring(ctx, repository.SearchArtistsBySubstringParams{
Column1: pgtype.Text{String: q, Valid: true},
Limit: searchItemLimit,
})
if err != nil {
return nil, err
}
ret := make([]*models.Artist, len(rows))
for i, row := range rows {
ret[i] = &models.Artist{
ID: row.ID,
MbzID: row.MusicBrainzID,
Name: row.Name,
Image: row.Image,
}
}
return ret, nil
} else {
rows, err := d.q.SearchArtists(ctx, repository.SearchArtistsParams{
Similarity: q,
Limit: searchItemLimit,
})
if err != nil {
return nil, err
}
ret := make([]*models.Artist, len(rows))
for i, row := range rows {
ret[i] = &models.Artist{
ID: row.ID,
MbzID: row.MusicBrainzID,
Name: row.Name,
Image: row.Image,
}
}
return ret, nil
}
}
func (d *Psql) SearchAlbums(ctx context.Context, q string) ([]*models.Album, error) {
if len(q) < substringSearchLength {
rows, err := d.q.SearchReleasesBySubstring(ctx, repository.SearchReleasesBySubstringParams{
Column1: pgtype.Text{String: q, Valid: true},
Limit: searchItemLimit,
})
if err != nil {
return nil, err
}
ret := make([]*models.Album, len(rows))
for i, row := range rows {
ret[i] = &models.Album{
ID: row.ID,
MbzID: row.MusicBrainzID,
Title: row.Title,
VariousArtists: row.VariousArtists,
Image: row.Image,
}
err = json.Unmarshal(row.Artists, &ret[i].Artists)
if err != nil {
return nil, err
}
}
return ret, nil
} else {
rows, err := d.q.SearchReleases(ctx, repository.SearchReleasesParams{
Similarity: q,
Limit: searchItemLimit,
})
if err != nil {
return nil, err
}
ret := make([]*models.Album, len(rows))
for i, row := range rows {
ret[i] = &models.Album{
ID: row.ID,
MbzID: row.MusicBrainzID,
Title: row.Title,
VariousArtists: row.VariousArtists,
Image: row.Image,
}
err = json.Unmarshal(row.Artists, &ret[i].Artists)
if err != nil {
return nil, err
}
}
return ret, nil
}
}
func (d *Psql) SearchTracks(ctx context.Context, q string) ([]*models.Track, error) {
if len(q) < substringSearchLength {
rows, err := d.q.SearchTracksBySubstring(ctx, repository.SearchTracksBySubstringParams{
Column1: pgtype.Text{String: q, Valid: true},
Limit: searchItemLimit,
})
if err != nil {
return nil, err
}
ret := make([]*models.Track, len(rows))
for i, row := range rows {
ret[i] = &models.Track{
ID: row.ID,
MbzID: row.MusicBrainzID,
Title: row.Title,
Image: row.Image,
}
err = json.Unmarshal(row.Artists, &ret[i].Artists)
if err != nil {
return nil, err
}
}
return ret, nil
} else {
rows, err := d.q.SearchTracks(ctx, repository.SearchTracksParams{
Similarity: q,
Limit: searchItemLimit,
})
if err != nil {
return nil, err
}
ret := make([]*models.Track, len(rows))
for i, row := range rows {
ret[i] = &models.Track{
ID: row.ID,
MbzID: row.MusicBrainzID,
Title: row.Title,
Image: row.Image,
}
err = json.Unmarshal(row.Artists, &ret[i].Artists)
if err != nil {
return nil, err
}
}
return ret, nil
}
}

View file

@ -0,0 +1,116 @@
package psql_test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setupTestDataForSearch(t *testing.T) {
truncateTestData(t)
// Insert artists
err := store.Exec(context.Background(),
`INSERT INTO artists (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000001'),
('00000000-0000-0000-0000-000000000002')`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
VALUES (1, 'Artist One With A Long Name', 'Testing', true),
(2, 'Artist Two', 'Testing', true)`)
require.NoError(t, err)
// Insert albums
err = store.Exec(context.Background(),
`INSERT INTO releases (musicbrainz_id, various_artists)
VALUES ('11111111-1111-1111-1111-111111111111', false),
('22222222-2222-2222-2222-222222222222', true)`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
VALUES (1, 'Album One With A Long Name', 'Testing', true),
(2, 'Album Two', 'Testing', true)`)
require.NoError(t, err)
// Insert tracks
err = store.Exec(context.Background(),
`INSERT INTO tracks (musicbrainz_id, release_id)
VALUES ('33333333-3333-3333-3333-333333333333', 1),
('44444444-4444-4444-4444-444444444444', 2)`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO track_aliases (track_id, alias, source, is_primary)
VALUES (1, 'Track One With A Long Name', 'Testing', true),
(2, 'Track Two', 'Testing', true)`)
require.NoError(t, err)
// Associate artists with albums and tracks
err = store.Exec(context.Background(),
`INSERT INTO artist_releases (artist_id, release_id)
VALUES (1, 1), (2, 2)`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO artist_tracks (artist_id, track_id)
VALUES (1, 1), (2, 2)`)
require.NoError(t, err)
}
func TestSearchArtists(t *testing.T) {
ctx := context.Background()
setupTestDataForSearch(t)
// Search for "Artist One With A Long Name"
results, err := store.SearchArtists(ctx, "Artist One With A Long Name")
require.NoError(t, err)
require.Len(t, results, 1)
assert.Equal(t, "Artist One With A Long Name", results[0].Name)
// Search for substring "Artist"
results, err = store.SearchArtists(ctx, "Arti")
require.NoError(t, err)
require.Len(t, results, 2)
truncateTestData(t)
}
func TestSearchAlbums(t *testing.T) {
ctx := context.Background()
setupTestDataForSearch(t)
// Search for "Album One With A Long Name"
results, err := store.SearchAlbums(ctx, "Album One With A Long Name")
require.NoError(t, err)
require.Len(t, results, 1)
assert.Equal(t, "Album One With A Long Name", results[0].Title)
// Search for substring "Album"
results, err = store.SearchAlbums(ctx, "Albu")
require.NoError(t, err)
require.Len(t, results, 2)
assert.NotNil(t, results[0].Artists)
truncateTestData(t)
}
func TestSearchTracks(t *testing.T) {
ctx := context.Background()
setupTestDataForSearch(t)
// Search for "Track One With A Long Name"
results, err := store.SearchTracks(ctx, "Track One With A Long Name")
require.NoError(t, err)
require.Len(t, results, 1)
assert.Equal(t, "Track One With A Long Name", results[0].Title)
// Search for substring "Track"
results, err = store.SearchTracks(ctx, "Trac")
require.NoError(t, err)
require.Len(t, results, 2)
assert.NotNil(t, results[0].Artists)
truncateTestData(t)
}

View file

@ -0,0 +1,59 @@
package psql
import (
"context"
"errors"
"time"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
)
func (d *Psql) SaveSession(ctx context.Context, userID int32, expiresAt time.Time, persistent bool) (*models.Session, error) {
session, err := d.q.InsertSession(ctx, repository.InsertSessionParams{
ID: uuid.New(),
UserID: userID,
ExpiresAt: expiresAt,
Persistent: persistent,
})
if err != nil {
return nil, err
}
return &models.Session{
ID: session.ID,
UserID: session.UserID,
CreatedAt: session.CreatedAt,
ExpiresAt: session.ExpiresAt,
Persistent: session.Persistent,
}, nil
}
func (d *Psql) RefreshSession(ctx context.Context, sessionId uuid.UUID, expiresAt time.Time) error {
return d.q.UpdateSessionExpiry(ctx, repository.UpdateSessionExpiryParams{
ID: sessionId,
ExpiresAt: expiresAt,
})
}
func (d *Psql) DeleteSession(ctx context.Context, sessionId uuid.UUID) error {
return d.q.DeleteSession(ctx, sessionId)
}
// Returns nil, nil when no database entries are found
func (d *Psql) GetUserBySession(ctx context.Context, sessionId uuid.UUID) (*models.User, error) {
row, err := d.q.GetUserBySession(ctx, sessionId)
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
}
return &models.User{
ID: row.ID,
Username: row.Username,
Password: row.Password,
Role: models.UserRole(row.Role),
}, nil
}

View file

@ -0,0 +1,101 @@
package psql_test
import (
"context"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func truncateTestDataForSessions(t *testing.T) {
err := store.Exec(context.Background(),
`TRUNCATE
sessions
RESTART IDENTITY CASCADE`,
)
require.NoError(t, err)
}
func TestSaveSession(t *testing.T) {
ctx := context.Background()
// Save a session for the user
expiresAt := time.Now().Add(24 * time.Hour).UTC()
session, err := store.SaveSession(ctx, 1, expiresAt, true)
require.NoError(t, err)
require.NotNil(t, session)
assert.Equal(t, int32(1), session.UserID)
assert.Equal(t, true, session.Persistent)
assert.WithinDuration(t, expiresAt, session.ExpiresAt, time.Second)
truncateTestDataForSessions(t)
}
func TestRefreshSession(t *testing.T) {
ctx := context.Background()
// Save a session first
expiresAt := time.Now().Add(-1 * time.Minute)
session, err := store.SaveSession(ctx, 1, expiresAt, true)
require.NoError(t, err)
// Refresh the session expiry
newExpiresAt := time.Now().Add(48 * time.Hour)
err = store.RefreshSession(ctx, session.ID, newExpiresAt)
require.NoError(t, err)
// Can only retrieve a session with an expiresAt > time.Now()
_, err = store.GetUserBySession(ctx, session.ID)
require.NoError(t, err)
truncateTestDataForSessions(t)
}
func TestDeleteSession(t *testing.T) {
ctx := context.Background()
// Save a session first
expiresAt := time.Now().Add(24 * time.Hour)
session, err := store.SaveSession(ctx, 1, expiresAt, true)
require.NoError(t, err)
// Delete the session
err = store.DeleteSession(ctx, session.ID)
require.NoError(t, err)
// Verify the session was deleted
var count int
count, err = store.Count(ctx, `SELECT COUNT(*) FROM sessions WHERE id = $1`, session.ID)
require.NoError(t, err)
assert.Equal(t, 0, count)
truncateTestDataForSessions(t)
}
func TestGetUserBySession(t *testing.T) {
ctx := context.Background()
// Save a session first
expiresAt := time.Now().Add(24 * time.Hour)
session, err := store.SaveSession(ctx, 1, expiresAt, true)
require.NoError(t, err)
// Get the user by session
user, err := store.GetUserBySession(ctx, session.ID)
require.NoError(t, err)
require.NotNil(t, user)
assert.Equal(t, int32(1), user.ID)
assert.Equal(t, "test", user.Username)
assert.Equal(t, []uint8([]byte{0xab, 0xc1, 0x23}), user.Password)
assert.Equal(t, "user", string(user.Role))
// Test for a non-existent session
nonExistentSessionID := uuid.New()
user, err = store.GetUserBySession(ctx, nonExistentSessionID)
require.NoError(t, err)
assert.Nil(t, user)
truncateTestDataForSessions(t)
}

View file

@ -0,0 +1,119 @@
package psql
import (
"context"
"encoding/json"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/gabehf/koito/internal/utils"
)
func (d *Psql) GetTopAlbumsPaginated(ctx context.Context, opts db.GetItemsOpts) (*db.PaginatedResponse[*models.Album], error) {
l := logger.FromContext(ctx)
offset := (opts.Page - 1) * opts.Limit
t1, t2, err := utils.DateRange(opts.Week, opts.Month, opts.Year)
if err != nil {
return nil, err
}
if opts.Month == 0 && opts.Year == 0 {
// use period, not date range
t2 = time.Now()
t1 = db.StartTimeFromPeriod(opts.Period)
}
if opts.Limit == 0 {
opts.Limit = DefaultItemsPerPage
}
var rgs []*models.Album
var count int64
if opts.ArtistID != 0 {
l.Debug().Msgf("Fetching top %d albums from artist id %d with period %s on page %d from range %v to %v",
opts.Limit, opts.ArtistID, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetTopReleasesFromArtist(ctx, repository.GetTopReleasesFromArtistParams{
ArtistID: int32(opts.ArtistID),
Limit: int32(opts.Limit),
Offset: int32(offset),
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return nil, err
}
rgs = make([]*models.Album, len(rows))
l.Debug().Msgf("Database responded with %d items", len(rows))
for i, v := range rows {
artists := make([]models.SimpleArtist, 0)
err = json.Unmarshal(v.Artists, &artists)
if err != nil {
l.Err(err).Msgf("Error unmarshalling artists for release group with id %d", v.ID)
artists = nil
}
rgs[i] = &models.Album{
ID: v.ID,
MbzID: v.MusicBrainzID,
Title: v.Title,
Image: v.Image,
Artists: artists,
VariousArtists: v.VariousArtists,
ListenCount: v.ListenCount,
}
}
count, err = d.q.CountReleasesFromArtist(ctx, int32(opts.ArtistID))
if err != nil {
return nil, err
}
} else {
l.Debug().Msgf("Fetching top %d albums with period %s on page %d from range %v to %v",
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetTopReleasesPaginated(ctx, repository.GetTopReleasesPaginatedParams{
ListenedAt: t1,
ListenedAt_2: t2,
Limit: int32(opts.Limit),
Offset: int32(offset),
})
if err != nil {
return nil, err
}
rgs = make([]*models.Album, len(rows))
l.Debug().Msgf("Database responded with %d items", len(rows))
for i, row := range rows {
artists := make([]models.SimpleArtist, 0)
err = json.Unmarshal(row.Artists, &artists)
if err != nil {
l.Err(err).Msgf("Error unmarshalling artists for release group with id %d", row.ID)
artists = nil
}
t := &models.Album{
Title: row.Title,
MbzID: row.MusicBrainzID,
ID: row.ID,
Image: row.Image,
Artists: artists,
VariousArtists: row.VariousArtists,
ListenCount: row.ListenCount,
}
rgs[i] = t
}
count, err = d.q.CountTopReleases(ctx, repository.CountTopReleasesParams{
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return nil, err
}
l.Debug().Msgf("Database responded with %d albums out of a total %d", len(rows), count)
}
return &db.PaginatedResponse[*models.Album]{
Items: rgs,
TotalCount: count,
ItemsPerPage: int32(opts.Limit),
HasNextPage: int64(offset+len(rgs)) < count,
CurrentPage: int32(opts.Page),
}, nil
}

View file

@ -0,0 +1,103 @@
package psql_test
import (
"context"
"testing"
"github.com/gabehf/koito/internal/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetTopAlbumsPaginated(t *testing.T) {
testDataForTopItems(t)
ctx := context.Background()
// Test valid
resp, err := store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime})
require.NoError(t, err)
require.Len(t, resp.Items, 4)
assert.Equal(t, int64(4), resp.TotalCount)
assert.Equal(t, "Release One", resp.Items[0].Title)
assert.Equal(t, "Release Two", resp.Items[1].Title)
assert.Equal(t, "Release Three", resp.Items[2].Title)
assert.Equal(t, "Release Four", resp.Items[3].Title)
// Test pagination
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 2, Period: db.PeriodAllTime})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, "Release Two", resp.Items[0].Title)
// Test page out of range
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 10, Period: db.PeriodAllTime})
require.NoError(t, err)
require.Empty(t, resp.Items)
assert.False(t, resp.HasNextPage)
// Test invalid inputs
_, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Limit: -1, Page: 0})
assert.Error(t, err)
_, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: -1})
assert.Error(t, err)
// Test specify period
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodDay})
require.NoError(t, err)
require.Len(t, resp.Items, 0) // empty
assert.Equal(t, int64(0), resp.TotalCount)
// should default to PeriodDay
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{})
require.NoError(t, err)
require.Len(t, resp.Items, 0) // empty
assert.Equal(t, int64(0), resp.TotalCount)
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodWeek})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Release Four", resp.Items[0].Title)
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodMonth})
require.NoError(t, err)
require.Len(t, resp.Items, 2)
assert.Equal(t, int64(2), resp.TotalCount)
assert.Equal(t, "Release Three", resp.Items[0].Title)
assert.Equal(t, "Release Four", resp.Items[1].Title)
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodYear})
require.NoError(t, err)
require.Len(t, resp.Items, 3)
assert.Equal(t, int64(3), resp.TotalCount)
assert.Equal(t, "Release Two", resp.Items[0].Title)
assert.Equal(t, "Release Three", resp.Items[1].Title)
assert.Equal(t, "Release Four", resp.Items[2].Title)
// test specific artist
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodYear, ArtistID: 2})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Release Two", resp.Items[0].Title)
// Test specify dates
testDataAbsoluteListenTimes(t)
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Year: 2023})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Release One", resp.Items[0].Title)
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Month: 6, Year: 2024})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Release Two", resp.Items[0].Title)
// invalid, year required with month
_, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Month: 10})
require.Error(t, err)
}

View file

@ -0,0 +1,67 @@
package psql
import (
"context"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/gabehf/koito/internal/utils"
)
func (d *Psql) GetTopArtistsPaginated(ctx context.Context, opts db.GetItemsOpts) (*db.PaginatedResponse[*models.Artist], error) {
l := logger.FromContext(ctx)
offset := (opts.Page - 1) * opts.Limit
t1, t2, err := utils.DateRange(opts.Week, opts.Month, opts.Year)
if err != nil {
return nil, err
}
if opts.Month == 0 && opts.Year == 0 {
// use period, not date range
t2 = time.Now()
t1 = db.StartTimeFromPeriod(opts.Period)
}
if opts.Limit == 0 {
opts.Limit = DefaultItemsPerPage
}
l.Debug().Msgf("Fetching top %d artists with period %s on page %d from range %v to %v",
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetTopArtistsPaginated(ctx, repository.GetTopArtistsPaginatedParams{
ListenedAt: t1,
ListenedAt_2: t2,
Limit: int32(opts.Limit),
Offset: int32(offset),
})
if err != nil {
return nil, err
}
rgs := make([]*models.Artist, len(rows))
for i, row := range rows {
t := &models.Artist{
Name: row.Name,
MbzID: row.MusicBrainzID,
ID: row.ID,
Image: row.Image,
ListenCount: row.ListenCount,
}
rgs[i] = t
}
count, err := d.q.CountTopArtists(ctx, repository.CountTopArtistsParams{
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return nil, err
}
l.Debug().Msgf("Database responded with %d artists out of a total %d", len(rows), count)
return &db.PaginatedResponse[*models.Artist]{
Items: rgs,
TotalCount: count,
ItemsPerPage: int32(opts.Limit),
HasNextPage: int64(offset+len(rgs)) < count,
CurrentPage: int32(opts.Page),
}, nil
}

View file

@ -0,0 +1,96 @@
package psql_test
import (
"context"
"testing"
"github.com/gabehf/koito/internal/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetTopArtistsPaginated(t *testing.T) {
testDataForTopItems(t)
ctx := context.Background()
// Test valid
resp, err := store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime})
require.NoError(t, err)
require.Len(t, resp.Items, 4)
assert.Equal(t, int64(4), resp.TotalCount)
assert.Equal(t, "Artist One", resp.Items[0].Name)
assert.Equal(t, "Artist Two", resp.Items[1].Name)
assert.Equal(t, "Artist Three", resp.Items[2].Name)
assert.Equal(t, "Artist Four", resp.Items[3].Name)
// Test pagination
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 2, Period: db.PeriodAllTime})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, "Artist Two", resp.Items[0].Name)
// Test page out of range
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 10, Period: db.PeriodAllTime})
require.NoError(t, err)
assert.Empty(t, resp.Items)
assert.False(t, resp.HasNextPage)
// Test invalid inputs
_, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Limit: -1, Page: 0})
assert.Error(t, err)
_, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: -1})
assert.Error(t, err)
// Test specify period
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodDay})
require.NoError(t, err)
require.Len(t, resp.Items, 0) // empty
assert.Equal(t, int64(0), resp.TotalCount)
// should default to PeriodDay
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{})
require.NoError(t, err)
require.Len(t, resp.Items, 0) // empty
assert.Equal(t, int64(0), resp.TotalCount)
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodWeek})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Artist Four", resp.Items[0].Name)
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodMonth})
require.NoError(t, err)
require.Len(t, resp.Items, 2)
assert.Equal(t, int64(2), resp.TotalCount)
assert.Equal(t, "Artist Three", resp.Items[0].Name)
assert.Equal(t, "Artist Four", resp.Items[1].Name)
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodYear})
require.NoError(t, err)
require.Len(t, resp.Items, 3)
assert.Equal(t, int64(3), resp.TotalCount)
assert.Equal(t, "Artist Two", resp.Items[0].Name)
assert.Equal(t, "Artist Three", resp.Items[1].Name)
assert.Equal(t, "Artist Four", resp.Items[2].Name)
// Test specify dates
testDataAbsoluteListenTimes(t)
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Year: 2023})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Artist One", resp.Items[0].Name)
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Month: 6, Year: 2024})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Artist Two", resp.Items[0].Name)
// invalid, year required with month
_, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Month: 10})
require.Error(t, err)
}

View file

@ -0,0 +1,160 @@
package psql
import (
"context"
"encoding/json"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/gabehf/koito/internal/utils"
)
func (d *Psql) GetTopTracksPaginated(ctx context.Context, opts db.GetItemsOpts) (*db.PaginatedResponse[*models.Track], error) {
l := logger.FromContext(ctx)
offset := (opts.Page - 1) * opts.Limit
t1, t2, err := utils.DateRange(opts.Week, opts.Month, opts.Year)
if err != nil {
return nil, err
}
if opts.Month == 0 && opts.Year == 0 {
// use period, not date range
t2 = time.Now()
t1 = db.StartTimeFromPeriod(opts.Period)
}
if opts.Limit == 0 {
opts.Limit = DefaultItemsPerPage
}
var tracks []*models.Track
var count int64
if opts.AlbumID > 0 {
l.Debug().Msgf("Fetching top %d tracks with period %s on page %d from range %v to %v",
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetTopTracksInReleasePaginated(ctx, repository.GetTopTracksInReleasePaginatedParams{
ListenedAt: t1,
ListenedAt_2: t2,
Limit: int32(opts.Limit),
Offset: int32(offset),
ReleaseID: int32(opts.AlbumID),
})
if err != nil {
return nil, err
}
tracks = make([]*models.Track, len(rows))
for i, row := range rows {
artists := make([]models.SimpleArtist, 0)
err = json.Unmarshal(row.Artists, &artists)
if err != nil {
l.Err(err).Msgf("Error unmarshalling artists for track with id %d", row.ID)
artists = nil
}
t := &models.Track{
Title: row.Title,
MbzID: row.MusicBrainzID,
ID: row.ID,
ListenCount: row.ListenCount,
Image: row.Image,
AlbumID: row.ReleaseID,
Artists: artists,
}
tracks[i] = t
}
count, err = d.q.CountTopTracksByRelease(ctx, repository.CountTopTracksByReleaseParams{
ListenedAt: t1,
ListenedAt_2: t2,
ReleaseID: int32(opts.AlbumID),
})
if err != nil {
return nil, err
}
} else if opts.ArtistID > 0 {
l.Debug().Msgf("Fetching top %d tracks with period %s on page %d from range %v to %v",
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetTopTracksByArtistPaginated(ctx, repository.GetTopTracksByArtistPaginatedParams{
ListenedAt: t1,
ListenedAt_2: t2,
Limit: int32(opts.Limit),
Offset: int32(offset),
ArtistID: int32(opts.ArtistID),
})
if err != nil {
return nil, err
}
tracks = make([]*models.Track, len(rows))
for i, row := range rows {
artists := make([]models.SimpleArtist, 0)
err = json.Unmarshal(row.Artists, &artists)
if err != nil {
l.Err(err).Msgf("Error unmarshalling artists for track with id %d", row.ID)
artists = nil
}
t := &models.Track{
Title: row.Title,
MbzID: row.MusicBrainzID,
ID: row.ID,
Image: row.Image,
ListenCount: row.ListenCount,
AlbumID: row.ReleaseID,
Artists: artists,
}
tracks[i] = t
}
count, err = d.q.CountTopTracksByArtist(ctx, repository.CountTopTracksByArtistParams{
ListenedAt: t1,
ListenedAt_2: t2,
ArtistID: int32(opts.ArtistID),
})
if err != nil {
return nil, err
}
} else {
l.Debug().Msgf("Fetching top %d tracks with period %s on page %d from range %v to %v",
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetTopTracksPaginated(ctx, repository.GetTopTracksPaginatedParams{
ListenedAt: t1,
ListenedAt_2: t2,
Limit: int32(opts.Limit),
Offset: int32(offset),
})
if err != nil {
return nil, err
}
tracks = make([]*models.Track, len(rows))
for i, row := range rows {
artists := make([]models.SimpleArtist, 0)
err = json.Unmarshal(row.Artists, &artists)
if err != nil {
l.Err(err).Msgf("Error unmarshalling artists for track with id %d", row.ID)
artists = nil
}
t := &models.Track{
Title: row.Title,
MbzID: row.MusicBrainzID,
ID: row.ID,
Image: row.Image,
ListenCount: row.ListenCount,
AlbumID: row.ReleaseID,
Artists: artists,
}
tracks[i] = t
}
count, err = d.q.CountTopTracks(ctx, repository.CountTopTracksParams{
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return nil, err
}
l.Debug().Msgf("Database responded with %d tracks out of a total %d", len(rows), count)
}
return &db.PaginatedResponse[*models.Track]{
Items: tracks,
TotalCount: count,
ItemsPerPage: int32(opts.Limit),
HasNextPage: int64(offset+len(tracks)) < count,
CurrentPage: int32(opts.Page),
}, nil
}

View file

@ -0,0 +1,118 @@
package psql_test
import (
"context"
"testing"
"github.com/gabehf/koito/internal/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetTopTracksPaginated(t *testing.T) {
testDataForTopItems(t)
ctx := context.Background()
// Test valid
resp, err := store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime})
require.NoError(t, err)
require.Len(t, resp.Items, 4)
assert.Equal(t, int64(4), resp.TotalCount)
assert.Equal(t, "Track One", resp.Items[0].Title)
assert.Equal(t, "Track Two", resp.Items[1].Title)
assert.Equal(t, "Track Three", resp.Items[2].Title)
assert.Equal(t, "Track Four", resp.Items[3].Title)
// ensure artists are included
require.Len(t, resp.Items[0].Artists, 1)
assert.Equal(t, "Artist One", resp.Items[0].Artists[0].Name)
// Test pagination
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 2, Period: db.PeriodAllTime})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, "Track Two", resp.Items[0].Title)
// Test page out of range
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 10, Period: db.PeriodAllTime})
require.NoError(t, err)
assert.Empty(t, resp.Items)
assert.False(t, resp.HasNextPage)
// Test invalid inputs
_, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Limit: -1, Page: 0})
assert.Error(t, err)
_, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: -1})
assert.Error(t, err)
// Test specify period
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodDay})
require.NoError(t, err)
require.Len(t, resp.Items, 0) // empty
assert.Equal(t, int64(0), resp.TotalCount)
// should default to PeriodDay
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{})
require.NoError(t, err)
require.Len(t, resp.Items, 0) // empty
assert.Equal(t, int64(0), resp.TotalCount)
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodWeek})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Track Four", resp.Items[0].Title)
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodMonth})
require.NoError(t, err)
require.Len(t, resp.Items, 2)
assert.Equal(t, int64(2), resp.TotalCount)
assert.Equal(t, "Track Three", resp.Items[0].Title)
assert.Equal(t, "Track Four", resp.Items[1].Title)
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodYear})
require.NoError(t, err)
require.Len(t, resp.Items, 3)
assert.Equal(t, int64(3), resp.TotalCount)
assert.Equal(t, "Track Two", resp.Items[0].Title)
assert.Equal(t, "Track Three", resp.Items[1].Title)
assert.Equal(t, "Track Four", resp.Items[2].Title)
// Test filter by artists and releases
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, ArtistID: 1})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Track One", resp.Items[0].Title)
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, AlbumID: 2})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Track Two", resp.Items[0].Title)
// when both artistID and albumID are specified, artist id is ignored
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, AlbumID: 2, ArtistID: 1})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Track Two", resp.Items[0].Title)
// Test specify dates
testDataAbsoluteListenTimes(t)
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Year: 2023})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Track One", resp.Items[0].Title)
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Month: 6, Year: 2024})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Track Two", resp.Items[0].Title)
// invalid, year required with month
_, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Month: 10})
require.Error(t, err)
}

298
internal/db/psql/track.go Normal file
View file

@ -0,0 +1,298 @@
package psql
import (
"context"
"errors"
"strings"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/gabehf/koito/internal/utils"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
)
func (d *Psql) GetTrack(ctx context.Context, opts db.GetTrackOpts) (*models.Track, error) {
l := logger.FromContext(ctx)
var track models.Track
if opts.ID != 0 {
l.Debug().Msgf("Fetching track from DB with id %d", opts.ID)
t, err := d.q.GetTrack(ctx, opts.ID)
if err != nil {
return nil, err
}
track = models.Track{
ID: t.ID,
MbzID: t.MusicBrainzID,
Title: t.Title,
AlbumID: t.ReleaseID,
Image: t.Image,
Duration: t.Duration,
}
} else if opts.MusicBrainzID != uuid.Nil {
l.Debug().Msgf("Fetching track from DB with MusicBrainz ID %s", opts.MusicBrainzID)
t, err := d.q.GetTrackByMbzID(ctx, &opts.MusicBrainzID)
if err != nil {
return nil, err
}
track = models.Track{
ID: t.ID,
MbzID: t.MusicBrainzID,
Title: t.Title,
AlbumID: t.ReleaseID,
Duration: t.Duration,
}
} else if len(opts.ArtistIDs) > 0 {
l.Debug().Msgf("Fetching track from DB with title '%s' and artist id(s) '%v'", opts.Title, opts.ArtistIDs)
t, err := d.q.GetTrackByTitleAndArtists(ctx, repository.GetTrackByTitleAndArtistsParams{
Title: opts.Title,
Column2: opts.ArtistIDs,
})
if err != nil {
return nil, err
}
track = models.Track{
ID: t.ID,
MbzID: t.MusicBrainzID,
Title: t.Title,
AlbumID: t.ReleaseID,
Duration: t.Duration,
}
} else {
return nil, errors.New("insufficient information to get track")
}
count, err := d.q.CountListensFromTrack(ctx, repository.CountListensFromTrackParams{
ListenedAt: time.Unix(0, 0),
ListenedAt_2: time.Now(),
TrackID: track.ID,
})
if err != nil {
l.Err(err).Msgf("Failed to get listen count for track with id %d", track.ID)
}
track.ListenCount = count
return &track, nil
}
func (d *Psql) SaveTrack(ctx context.Context, opts db.SaveTrackOpts) (*models.Track, error) {
// create track in DB
l := logger.FromContext(ctx)
var insertMbzID *uuid.UUID
if opts.RecordingMbzID != uuid.Nil {
insertMbzID = &opts.RecordingMbzID
}
if len(opts.ArtistIDs) < 1 {
return nil, errors.New("required parameter 'ArtistIDs' missing")
}
for _, aid := range opts.ArtistIDs {
if aid == 0 {
return nil, errors.New("none of 'ArtistIDs' may be 0")
}
}
if opts.AlbumID == 0 {
return nil, errors.New("required parameter 'AlbumID' missing")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return nil, err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
l.Debug().Msgf("Inserting new track '%s' into DB", opts.Title)
trackRow, err := qtx.InsertTrack(ctx, repository.InsertTrackParams{
MusicBrainzID: insertMbzID,
ReleaseID: opts.AlbumID,
})
if err != nil {
return nil, err
}
// insert associated artists
for _, aid := range opts.ArtistIDs {
err = qtx.AssociateArtistToTrack(ctx, repository.AssociateArtistToTrackParams{
ArtistID: aid,
TrackID: trackRow.ID,
})
if err != nil {
return nil, err
}
}
// insert primary alias
err = qtx.InsertTrackAlias(ctx, repository.InsertTrackAliasParams{
TrackID: trackRow.ID,
Alias: opts.Title,
Source: "Canonical",
IsPrimary: true,
})
if err != nil {
return nil, err
}
err = tx.Commit(ctx)
if err != nil {
return nil, err
}
return &models.Track{
ID: trackRow.ID,
MbzID: insertMbzID,
Title: opts.Title,
}, nil
}
func (d *Psql) UpdateTrack(ctx context.Context, opts db.UpdateTrackOpts) error {
l := logger.FromContext(ctx)
if opts.ID == 0 {
return errors.New("track id not specified")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
if opts.MusicBrainzID != uuid.Nil {
l.Debug().Msgf("Updating MusicBrainz ID for track %d", opts.ID)
err := qtx.UpdateTrackMbzID(ctx, repository.UpdateTrackMbzIDParams{
ID: opts.ID,
MusicBrainzID: &opts.MusicBrainzID,
})
if err != nil {
return err
}
}
if opts.Duration != 0 {
l.Debug().Msgf("Updating duration for track %d", opts.ID)
err := qtx.UpdateTrackDuration(ctx, repository.UpdateTrackDurationParams{
ID: opts.ID,
Duration: opts.Duration,
})
if err != nil {
return err
}
}
return tx.Commit(ctx)
}
func (d *Psql) SaveTrackAliases(ctx context.Context, id int32, aliases []string, source string) error {
l := logger.FromContext(ctx)
if id == 0 {
return errors.New("track id not specified")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
existing, err := qtx.GetAllTrackAliases(ctx, id)
if err != nil {
return err
}
for _, v := range existing {
aliases = append(aliases, v.Alias)
}
utils.Unique(&aliases)
for _, alias := range aliases {
if strings.TrimSpace(alias) == "" {
return errors.New("aliases cannot be blank")
}
err = qtx.InsertTrackAlias(ctx, repository.InsertTrackAliasParams{
Alias: strings.TrimSpace(alias),
TrackID: id,
Source: source,
IsPrimary: false,
})
if err != nil {
return err
}
}
return tx.Commit(ctx)
}
func (d *Psql) DeleteTrack(ctx context.Context, id int32) error {
return d.q.DeleteTrack(ctx, id)
}
func (d *Psql) DeleteTrackAlias(ctx context.Context, id int32, alias string) error {
return d.q.DeleteTrackAlias(ctx, repository.DeleteTrackAliasParams{
TrackID: id,
Alias: alias,
})
}
func (d *Psql) GetAllTrackAliases(ctx context.Context, id int32) ([]models.Alias, error) {
rows, err := d.q.GetAllTrackAliases(ctx, id)
if err != nil {
return nil, err
}
aliases := make([]models.Alias, len(rows))
for i, row := range rows {
aliases[i] = models.Alias{
ID: id,
Alias: row.Alias,
Source: row.Source,
Primary: row.IsPrimary,
}
}
return aliases, nil
}
func (d *Psql) SetPrimaryTrackAlias(ctx context.Context, id int32, alias string) error {
l := logger.FromContext(ctx)
if id == 0 {
return errors.New("artist id not specified")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
// get all aliases
aliases, err := qtx.GetAllTrackAliases(ctx, id)
if err != nil {
return err
}
primary := ""
exists := false
for _, v := range aliases {
if v.Alias == alias {
exists = true
}
if v.IsPrimary {
primary = v.Alias
}
}
if primary == alias {
// no-op rename
return nil
}
if !exists {
return errors.New("alias does not exist")
}
err = qtx.SetTrackAliasPrimaryStatus(ctx, repository.SetTrackAliasPrimaryStatusParams{
TrackID: id,
Alias: alias,
IsPrimary: true,
})
if err != nil {
return err
}
err = qtx.SetTrackAliasPrimaryStatus(ctx, repository.SetTrackAliasPrimaryStatusParams{
TrackID: id,
Alias: primary,
IsPrimary: false,
})
if err != nil {
return err
}
return tx.Commit(ctx)
}

View file

@ -0,0 +1,213 @@
package psql_test
import (
"context"
"testing"
"github.com/gabehf/koito/internal/db"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func testDataForTracks(t *testing.T) {
truncateTestData(t)
// Insert artists
err := store.Exec(context.Background(),
`INSERT INTO artists (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000001'),
('00000000-0000-0000-0000-000000000002')`)
require.NoError(t, err)
// Insert artist aliases
err = store.Exec(context.Background(),
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
VALUES (1, 'Artist One', 'Testing', true),
(2, 'Artist Two', 'Testing', true)`)
require.NoError(t, err)
// Insert release groups
err = store.Exec(context.Background(),
`INSERT INTO releases (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000011'),
('00000000-0000-0000-0000-000000000022')`)
require.NoError(t, err)
// Insert release aliases
err = store.Exec(context.Background(),
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
VALUES (1, 'Release Group One', 'Testing', true),
(2, 'Release Group Two', 'Testing', true)`)
require.NoError(t, err)
// Insert tracks
err = store.Exec(context.Background(),
`INSERT INTO tracks (musicbrainz_id, release_id)
VALUES ('11111111-1111-1111-1111-111111111111', 1),
('22222222-2222-2222-2222-222222222222', 2)`)
require.NoError(t, err)
// Insert track aliases
err = store.Exec(context.Background(),
`INSERT INTO track_aliases (track_id, alias, source, is_primary)
VALUES (1, 'Track One', 'Testing', true),
(2, 'Track Two', 'Testing', true)`)
require.NoError(t, err)
// Associate tracks with artists
err = store.Exec(context.Background(),
`INSERT INTO artist_tracks (artist_id, track_id)
VALUES (1, 1), (2, 2)`)
require.NoError(t, err)
}
func TestGetTrack(t *testing.T) {
testDataForTracks(t)
ctx := context.Background()
// Test GetTrack by ID
track, err := store.GetTrack(ctx, db.GetTrackOpts{ID: 1})
require.NoError(t, err)
assert.Equal(t, int32(1), track.ID)
assert.Equal(t, "Track One", track.Title)
assert.Equal(t, uuid.MustParse("11111111-1111-1111-1111-111111111111"), *track.MbzID)
// Test GetTrack by MusicBrainzID
track, err = store.GetTrack(ctx, db.GetTrackOpts{MusicBrainzID: uuid.MustParse("22222222-2222-2222-2222-222222222222")})
require.NoError(t, err)
assert.Equal(t, int32(2), track.ID)
assert.Equal(t, "Track Two", track.Title)
// Test GetTrack by Title and ArtistIDs
track, err = store.GetTrack(ctx, db.GetTrackOpts{
Title: "Track One",
ArtistIDs: []int32{1},
})
require.NoError(t, err)
assert.Equal(t, int32(1), track.ID)
assert.Equal(t, "Track One", track.Title)
// Test GetTrack with insufficient information
_, err = store.GetTrack(ctx, db.GetTrackOpts{})
assert.Error(t, err)
}
func TestSaveTrack(t *testing.T) {
testDataForTracks(t)
ctx := context.Background()
// Test SaveTrack with valid inputs
track, err := store.SaveTrack(ctx, db.SaveTrackOpts{
Title: "New Track",
ArtistIDs: []int32{1},
RecordingMbzID: uuid.MustParse("33333333-3333-3333-3333-333333333333"),
AlbumID: 1,
})
require.NoError(t, err)
assert.Equal(t, "New Track", track.Title)
assert.Equal(t, uuid.MustParse("33333333-3333-3333-3333-333333333333"), *track.MbzID)
// Verify artist associations exist
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM artist_tracks
WHERE artist_id = $1 AND track_id = $2
)`, 1, track.ID)
require.NoError(t, err)
assert.True(t, exists, "expected artist association to exist")
// Verify alias exists
exists, err = store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM track_aliases
WHERE track_id = $1 AND is_primary = true
)`, track.ID)
require.NoError(t, err)
assert.True(t, exists, "expected primary alias to exist")
// Test SaveTrack with missing ArtistIDs
_, err = store.SaveTrack(ctx, db.SaveTrackOpts{
Title: "Invalid Track",
ArtistIDs: []int32{},
RecordingMbzID: uuid.MustParse("44444444-4444-4444-4444-444444444444"),
})
assert.Error(t, err)
// Test SaveTrack with invalid ArtistIDs
_, err = store.SaveTrack(ctx, db.SaveTrackOpts{
Title: "Invalid Track",
ArtistIDs: []int32{0},
RecordingMbzID: uuid.MustParse("55555555-5555-5555-5555-555555555555"),
})
assert.Error(t, err)
}
func TestUpdateTrack(t *testing.T) {
testDataForTracks(t)
ctx := context.Background()
newMbzID := uuid.MustParse("66666666-6666-6666-6666-666666666666")
newDuration := 100
err := store.UpdateTrack(ctx, db.UpdateTrackOpts{
ID: 1,
MusicBrainzID: newMbzID,
Duration: int32(newDuration),
})
require.NoError(t, err)
// Verify the update
track, err := store.GetTrack(ctx, db.GetTrackOpts{ID: 1})
require.NoError(t, err)
require.Equal(t, newMbzID, *track.MbzID)
require.EqualValues(t, newDuration, track.Duration)
// Test UpdateTrack with missing ID
err = store.UpdateTrack(ctx, db.UpdateTrackOpts{
ID: 0,
MusicBrainzID: newMbzID,
Duration: int32(newDuration),
})
assert.Error(t, err)
// Test UpdateTrack with nil MusicBrainz ID
err = store.UpdateTrack(ctx, db.UpdateTrackOpts{
ID: 1,
MusicBrainzID: uuid.Nil,
Duration: int32(newDuration),
})
assert.NoError(t, err) // No update should occur
}
func TestTrackAliases(t *testing.T) {
testDataForTracks(t)
ctx := context.Background()
err := store.SaveTrackAliases(ctx, 1, []string{"Alias One", "Alias Two"}, "Testing")
require.NoError(t, err)
aliases, err := store.GetAllTrackAliases(ctx, 1)
require.NoError(t, err)
assert.Len(t, aliases, 3)
err = store.SetPrimaryTrackAlias(ctx, 1, "Alias One")
require.NoError(t, err)
track, err := store.GetTrack(ctx, db.GetTrackOpts{ID: 1})
require.NoError(t, err)
assert.Equal(t, "Alias One", track.Title)
err = store.SetPrimaryTrackAlias(ctx, 1, "Fake Alias")
require.Error(t, err)
store.SetPrimaryTrackAlias(ctx, 1, "Track One")
}
func TestDeleteTrack(t *testing.T) {
testDataForTracks(t)
ctx := context.Background()
err := store.DeleteTrack(ctx, 2)
require.NoError(t, err)
_, err = store.Count(ctx, `SELECT * FROM tracks WHERE id = 2`)
require.ErrorIs(t, err, pgx.ErrNoRows) // no rows error
}

219
internal/db/psql/user.go Normal file
View file

@ -0,0 +1,219 @@
package psql
import (
"context"
"errors"
"regexp"
"strings"
"unicode/utf8"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/jackc/pgx/v5"
"golang.org/x/crypto/bcrypt"
)
// Returns nil, nil when no database entries are found
func (d *Psql) GetUserByUsername(ctx context.Context, username string) (*models.User, error) {
row, err := d.q.GetUserByUsername(ctx, strings.ToLower(username))
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
}
return &models.User{
ID: row.ID,
Username: row.Username,
Password: row.Password,
Role: models.UserRole(row.Role),
}, nil
}
// Returns nil, nil when no database entries are found
func (d *Psql) GetUserByApiKey(ctx context.Context, key string) (*models.User, error) {
row, err := d.q.GetUserByApiKey(ctx, key)
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
}
return &models.User{
ID: row.ID,
Username: row.Username,
Password: row.Password,
Role: models.UserRole(row.Role),
}, nil
}
func (d *Psql) SaveUser(ctx context.Context, opts db.SaveUserOpts) (*models.User, error) {
l := logger.FromContext(ctx)
err := ValidateUsername(opts.Username)
if err != nil {
l.Debug().AnErr("validator_notice", err).Msgf("Username failed validation: %s", opts.Username)
return nil, err
}
pw, err := ValidateAndNormalizePassword(opts.Password)
if err != nil {
l.Debug().AnErr("validator_notice", err).Msgf("Password failed validation")
return nil, err
}
if opts.Role == "" {
opts.Role = models.UserRoleUser
}
hashPw, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost)
if err != nil {
l.Err(err).Msg("Failed to generate hashed password")
return nil, err
}
u, err := d.q.InsertUser(ctx, repository.InsertUserParams{
Username: strings.ToLower(opts.Username),
Password: hashPw,
Role: repository.Role(opts.Role),
})
if err != nil {
return nil, err
}
return &models.User{
ID: u.ID,
Username: u.Username,
Role: models.UserRole(u.Role),
}, nil
}
func (d *Psql) SaveApiKey(ctx context.Context, opts db.SaveApiKeyOpts) (*models.ApiKey, error) {
row, err := d.q.InsertApiKey(ctx, repository.InsertApiKeyParams{
Key: opts.Key,
Label: opts.Label,
UserID: opts.UserID,
})
if err != nil {
return nil, err
}
return &models.ApiKey{
ID: row.ID,
UserID: row.UserID,
Key: row.Key,
Label: row.Label,
CreatedAt: row.CreatedAt.Time,
}, nil
}
func (d *Psql) UpdateUser(ctx context.Context, opts db.UpdateUserOpts) error {
l := logger.FromContext(ctx)
if opts.ID == 0 {
return errors.New("user id is required")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
if opts.Username != "" {
err := ValidateUsername(opts.Username)
if err != nil {
l.Debug().AnErr("validator_notice", err).Msgf("Username failed validation: %s", opts.Username)
return err
}
err = qtx.UpdateUserUsername(ctx, repository.UpdateUserUsernameParams{
ID: opts.ID,
Username: opts.Username,
})
if err != nil {
return err
}
}
if opts.Password != "" {
pw, err := ValidateAndNormalizePassword(opts.Password)
if err != nil {
l.Debug().AnErr("validator_notice", err).Msgf("Password failed validation")
return err
}
hashPw, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost)
if err != nil {
l.Err(err).Msg("Failed to generate hashed password")
return err
}
err = qtx.UpdateUserPassword(ctx, repository.UpdateUserPasswordParams{
ID: opts.ID,
Password: hashPw,
})
if err != nil {
return err
}
}
return tx.Commit(ctx)
}
func (d *Psql) GetApiKeysByUserID(ctx context.Context, id int32) ([]models.ApiKey, error) {
rows, err := d.q.GetAllApiKeysByUserID(ctx, id)
if err != nil {
return nil, err
}
keys := make([]models.ApiKey, len(rows))
for i, row := range rows {
keys[i] = models.ApiKey{
ID: row.ID,
Key: row.Key,
Label: row.Label,
UserID: row.UserID,
}
}
return keys, nil
}
func (d *Psql) UpdateApiKeyLabel(ctx context.Context, opts db.UpdateApiKeyLabelOpts) error {
return d.q.UpdateApiKeyLabel(ctx, repository.UpdateApiKeyLabelParams{
ID: opts.ID,
Label: opts.Label,
UserID: opts.UserID,
})
}
func (d *Psql) DeleteApiKey(ctx context.Context, id int32) error {
return d.q.DeleteApiKey(ctx, id)
}
func (d *Psql) CountUsers(ctx context.Context) (int64, error) {
return d.q.CountUsers(ctx)
}
const (
maxUsernameLength = 32
minUsernameLength = 1
maxPasswordLength = 128
minPasswordLength = 8
)
var usernameRegex = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`)
func ValidateUsername(username string) error {
length := utf8.RuneCountInString(username)
if length < minUsernameLength || length > maxUsernameLength {
return errors.New("username must be between 1 and 32 characters")
}
if !usernameRegex.MatchString(username) {
return errors.New("username can only contain [a-zA-Z0-9_.-]")
}
return nil
}
func ValidateAndNormalizePassword(password string) (string, error) {
length := utf8.RuneCountInString(password)
if length < minPasswordLength {
return "", errors.New("password must be at least 8 characters long")
}
if length > maxPasswordLength {
var truncated []rune
for i, r := range password {
if i >= maxPasswordLength {
break
}
truncated = append(truncated, r)
}
password = string(truncated)
}
return password, nil
}

View file

@ -0,0 +1,199 @@
package psql_test
import (
"context"
"testing"
"github.com/gabehf/koito/internal/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
)
func setupTestDataForUsers(t *testing.T) {
truncateTestDataForUsers(t)
// Insert additional test users
err := store.Exec(context.Background(),
`INSERT INTO users (username, password, role)
VALUES ('test_user', $1, 'user'),
('admin_user', $1, 'admin')`, []byte("hashed_password"))
require.NoError(t, err)
}
func truncateTestDataForUsers(t *testing.T) {
err := store.Exec(context.Background(),
`DELETE FROM users WHERE id NOT IN (1)`,
)
require.NoError(t, err)
err = store.Exec(context.Background(),
`ALTER SEQUENCE users_id_seq RESTART WITH 2`,
)
require.NoError(t, err)
err = store.Exec(context.Background(),
`TRUNCATE api_keys RESTART IDENTITY CASCADE`,
)
require.NoError(t, err)
}
func TestGetUserByUsername(t *testing.T) {
ctx := context.Background()
setupTestDataForUsers(t)
// Test fetching an existing user
user, err := store.GetUserByUsername(ctx, "test_user")
require.NoError(t, err)
require.NotNil(t, user)
assert.Equal(t, "test_user", user.Username)
assert.Equal(t, "user", string(user.Role))
// Test fetching a non-existent user
user, err = store.GetUserByUsername(ctx, "nonexistent_user")
require.NoError(t, err)
assert.Nil(t, user)
}
func TestGetUserByApiKey(t *testing.T) {
ctx := context.Background()
setupTestDataForUsers(t)
// Insert an API key for the test user
err := store.Exec(ctx, `INSERT INTO api_keys (key, label, user_id) VALUES ('test_key', 'Test Key', 2)`)
require.NoError(t, err)
// Test fetching a user by API key
user, err := store.GetUserByApiKey(ctx, "test_key")
require.NoError(t, err)
require.NotNil(t, user)
assert.Equal(t, int32(2), user.ID)
assert.Equal(t, "test_user", user.Username)
// Test fetching a user with a non-existent API key
user, err = store.GetUserByApiKey(ctx, "nonexistent_key")
require.NoError(t, err)
assert.Nil(t, user)
}
func TestSaveUser(t *testing.T) {
ctx := context.Background()
setupTestDataForUsers(t)
// Save a new user
opts := db.SaveUserOpts{
Username: "new_user",
Password: "secure_password",
Role: "user",
}
user, err := store.SaveUser(ctx, opts)
require.NoError(t, err)
require.NotNil(t, user)
assert.Equal(t, "new_user", user.Username)
assert.Equal(t, "user", string(user.Role))
// Verify the password was hashed
var hashedPassword []byte
err = store.QueryRow(ctx, `SELECT password FROM users WHERE username = $1`, "new_user").Scan(&hashedPassword)
require.NoError(t, err)
assert.NoError(t, bcrypt.CompareHashAndPassword(hashedPassword, []byte(opts.Password)))
// Test validation failures
_, err = store.SaveUser(ctx, db.SaveUserOpts{
Username: "Q!@JH(F_H@#!*HF#*)&@",
Password: "testpassword12345",
})
assert.Error(t, err)
_, err = store.SaveUser(ctx, db.SaveUserOpts{
Username: "test_user",
Password: "<3",
})
assert.Error(t, err)
}
func TestSaveApiKey(t *testing.T) {
ctx := context.Background()
setupTestDataForUsers(t)
// Save an API key for the test user
label := "New API Key"
opts := db.SaveApiKeyOpts{
Key: "new_api_key",
Label: label,
UserID: 2,
}
_, err := store.SaveApiKey(ctx, opts)
require.NoError(t, err)
// Verify the API key was saved
count, err := store.Count(ctx, `SELECT COUNT(*) FROM api_keys WHERE key = $1 AND user_id = $2`, opts.Key, opts.UserID)
require.NoError(t, err)
assert.Equal(t, 1, count)
}
func TestGetApiKeysByUserID(t *testing.T) {
ctx := context.Background()
setupTestDataForUsers(t)
// Insert API keys for the test user
err := store.Exec(ctx, `INSERT INTO api_keys (key, label, user_id) VALUES
('key1', 'Key 1', 2),
('key2', 'Key 2', 2)`)
require.NoError(t, err)
// Fetch API keys for the test user
keys, err := store.GetApiKeysByUserID(ctx, 2)
require.NoError(t, err)
require.Len(t, keys, 2)
assert.Equal(t, "key1", keys[0].Key)
assert.Equal(t, "key2", keys[1].Key)
}
func TestUpdateApiKeyLabel(t *testing.T) {
ctx := context.Background()
setupTestDataForUsers(t)
// Insert an API key for the test user
err := store.Exec(ctx, `INSERT INTO api_keys (key, label, user_id) VALUES ('key_to_update', 'Old Label', 2)`)
require.NoError(t, err)
// Update the API key label
opts := db.UpdateApiKeyLabelOpts{
ID: 1,
Label: "Updated Label",
UserID: 2,
}
err = store.UpdateApiKeyLabel(ctx, opts)
require.NoError(t, err)
// Verify the label was updated
var label string
err = store.QueryRow(ctx, `SELECT label FROM api_keys WHERE id = $1`, opts.ID).Scan(&label)
require.NoError(t, err)
assert.Equal(t, "Updated Label", label)
}
func TestDeleteApiKey(t *testing.T) {
ctx := context.Background()
setupTestDataForUsers(t)
// Insert an API key for the test user
err := store.Exec(ctx, `INSERT INTO api_keys (key, label, user_id) VALUES ('key_to_delete', 'Label', 2)`)
require.NoError(t, err)
// Delete the API key
err = store.DeleteApiKey(ctx, 1) // Assuming the ID is auto-generated and starts from 1
require.NoError(t, err)
// Verify the API key was deleted
count, err := store.Count(ctx, `SELECT COUNT(*) FROM api_keys WHERE id = $1`, 1)
require.NoError(t, err)
assert.Equal(t, 0, count)
}
func TestCountUsers(t *testing.T) {
ctx := context.Background()
setupTestDataForUsers(t)
// Count the number of users
count, err := store.Count(ctx, `SELECT COUNT(*) FROM users`)
require.NoError(t, err)
assert.GreaterOrEqual(t, count, 3) // Special user + test users
}

26
internal/db/types.go Normal file
View file

@ -0,0 +1,26 @@
package db
import (
"time"
)
type InformationSource string
const (
InformationSourceInferred InformationSource = "Inferred"
InformationSourceMusicBrainz InformationSource = "MusicBrainz"
InformationSourceUserProvided InformationSource = "User"
)
type ListenActivityItem struct {
Start time.Time `json:"start_time"`
Listens int64 `json:"listens"`
}
type PaginatedResponse[T any] struct {
Items []T `json:"items"`
TotalCount int64 `json:"total_record_count"`
ItemsPerPage int32 `json:"items_per_page"`
HasNextPage bool `json:"has_next_page"`
CurrentPage int32 `json:"current_page"`
}