mirror of
https://github.com/gabehf/Koito.git
synced 2026-03-15 18:35:55 -07:00
chore: initial public commit
This commit is contained in:
commit
fc9054b78c
250 changed files with 32809 additions and 0 deletions
312
internal/db/psql/album.go
Normal file
312
internal/db/psql/album.go
Normal file
|
|
@ -0,0 +1,312 @@
|
|||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/gabehf/koito/internal/logger"
|
||||
"github.com/gabehf/koito/internal/models"
|
||||
"github.com/gabehf/koito/internal/repository"
|
||||
"github.com/gabehf/koito/internal/utils"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
func (d *Psql) GetAlbum(ctx context.Context, opts db.GetAlbumOpts) (*models.Album, error) {
|
||||
l := logger.FromContext(ctx)
|
||||
|
||||
var row repository.ReleasesWithTitle
|
||||
var err error
|
||||
|
||||
if opts.ID != 0 {
|
||||
l.Debug().Msgf("Fetching album from DB with id %d", opts.ID)
|
||||
row, err = d.q.GetRelease(ctx, opts.ID)
|
||||
} else if opts.MusicBrainzID != uuid.Nil {
|
||||
l.Debug().Msgf("Fetching album from DB with MusicBrainz Release ID %s", opts.MusicBrainzID)
|
||||
row, err = d.q.GetReleaseByMbzID(ctx, &opts.MusicBrainzID)
|
||||
} else if opts.ArtistID != 0 && opts.Title != "" {
|
||||
l.Debug().Msgf("Fetching album from DB with artist_id %d and title %s", opts.ArtistID, opts.Title)
|
||||
row, err = d.q.GetReleaseByArtistAndTitle(ctx, repository.GetReleaseByArtistAndTitleParams{
|
||||
ArtistID: opts.ArtistID,
|
||||
Title: opts.Title,
|
||||
})
|
||||
} else if opts.ArtistID != 0 && len(opts.Titles) > 0 {
|
||||
l.Debug().Msgf("Fetching release group from DB with artist_id %d and titles %v", opts.ArtistID, opts.Titles)
|
||||
row, err = d.q.GetReleaseByArtistAndTitles(ctx, repository.GetReleaseByArtistAndTitlesParams{
|
||||
ArtistID: opts.ArtistID,
|
||||
Column1: opts.Titles,
|
||||
})
|
||||
} else {
|
||||
return nil, errors.New("insufficient information to get album")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
count, err := d.q.CountListensFromRelease(ctx, repository.CountListensFromReleaseParams{
|
||||
ListenedAt: time.Unix(0, 0),
|
||||
ListenedAt_2: time.Now(),
|
||||
ReleaseID: row.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &models.Album{
|
||||
ID: row.ID,
|
||||
MbzID: row.MusicBrainzID,
|
||||
Title: row.Title,
|
||||
Image: row.Image,
|
||||
VariousArtists: row.VariousArtists,
|
||||
ListenCount: count,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Psql) SaveAlbum(ctx context.Context, opts db.SaveAlbumOpts) (*models.Album, error) {
|
||||
l := logger.FromContext(ctx)
|
||||
var insertMbzID *uuid.UUID
|
||||
var insertImage *uuid.UUID
|
||||
if opts.MusicBrainzID != uuid.Nil {
|
||||
insertMbzID = &opts.MusicBrainzID
|
||||
}
|
||||
if opts.Image != uuid.Nil {
|
||||
insertImage = &opts.Image
|
||||
}
|
||||
if len(opts.ArtistIDs) < 1 {
|
||||
return nil, errors.New("required parameter 'ArtistIDs' missing")
|
||||
}
|
||||
for _, aid := range opts.ArtistIDs {
|
||||
if aid == 0 {
|
||||
return nil, errors.New("none of 'ArtistIDs' may be 0")
|
||||
}
|
||||
}
|
||||
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to begin transaction")
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
qtx := d.q.WithTx(tx)
|
||||
l.Debug().Msgf("Inserting release '%s' into DB", opts.Title)
|
||||
r, err := qtx.InsertRelease(ctx, repository.InsertReleaseParams{
|
||||
MusicBrainzID: insertMbzID,
|
||||
VariousArtists: opts.VariousArtists,
|
||||
Image: insertImage,
|
||||
ImageSource: pgtype.Text{String: opts.ImageSrc, Valid: opts.ImageSrc != ""},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, artistId := range opts.ArtistIDs {
|
||||
l.Debug().Msgf("Associating release '%s' to artist with ID %d", opts.Title, artistId)
|
||||
err = qtx.AssociateArtistToRelease(ctx, repository.AssociateArtistToReleaseParams{
|
||||
ArtistID: artistId,
|
||||
ReleaseID: r.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
l.Debug().Msgf("Saving canonical alias %s for release %d", opts.Title, r.ID)
|
||||
err = qtx.InsertReleaseAlias(ctx, repository.InsertReleaseAliasParams{
|
||||
ReleaseID: r.ID,
|
||||
Alias: opts.Title,
|
||||
Source: "Canonical",
|
||||
IsPrimary: true,
|
||||
})
|
||||
if err != nil {
|
||||
l.Err(err).Msgf("Failed to save canonical alias for album %d", r.ID)
|
||||
}
|
||||
|
||||
err = tx.Commit(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &models.Album{
|
||||
ID: r.ID,
|
||||
MbzID: r.MusicBrainzID,
|
||||
Title: opts.Title,
|
||||
Image: r.Image,
|
||||
VariousArtists: r.VariousArtists,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Psql) AddArtistsToAlbum(ctx context.Context, opts db.AddArtistsToAlbumOpts) error {
|
||||
l := logger.FromContext(ctx)
|
||||
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to begin transaction")
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
qtx := d.q.WithTx(tx)
|
||||
for _, id := range opts.ArtistIDs {
|
||||
err := qtx.AssociateArtistToRelease(ctx, repository.AssociateArtistToReleaseParams{
|
||||
ReleaseID: opts.AlbumID,
|
||||
ArtistID: id,
|
||||
})
|
||||
if err != nil {
|
||||
l.Error().Err(err).Msgf("Failed to associate release %d with artist %d", opts.AlbumID, id)
|
||||
}
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (d *Psql) UpdateAlbum(ctx context.Context, opts db.UpdateAlbumOpts) error {
|
||||
l := logger.FromContext(ctx)
|
||||
if opts.ID == 0 {
|
||||
return errors.New("missing album id")
|
||||
}
|
||||
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to begin transaction")
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
qtx := d.q.WithTx(tx)
|
||||
if opts.MusicBrainzID != uuid.Nil {
|
||||
l.Debug().Msgf("Updating release with ID %d with MusicBrainz ID %s", opts.ID, opts.MusicBrainzID)
|
||||
err := qtx.UpdateReleaseMbzID(ctx, repository.UpdateReleaseMbzIDParams{
|
||||
ID: opts.ID,
|
||||
MusicBrainzID: &opts.MusicBrainzID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if opts.Image != uuid.Nil {
|
||||
l.Debug().Msgf("Updating release with ID %d with image %s", opts.ID, opts.Image)
|
||||
err := qtx.UpdateReleaseImage(ctx, repository.UpdateReleaseImageParams{
|
||||
ID: opts.ID,
|
||||
Image: &opts.Image,
|
||||
ImageSource: pgtype.Text{String: opts.ImageSrc, Valid: opts.ImageSrc != ""},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (d *Psql) SaveAlbumAliases(ctx context.Context, id int32, aliases []string, source string) error {
|
||||
l := logger.FromContext(ctx)
|
||||
if id == 0 {
|
||||
return errors.New("album id not specified")
|
||||
}
|
||||
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to begin transaction")
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
qtx := d.q.WithTx(tx)
|
||||
existing, err := qtx.GetAllReleaseAliases(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, v := range existing {
|
||||
aliases = append(aliases, v.Alias)
|
||||
}
|
||||
utils.Unique(&aliases)
|
||||
for _, alias := range aliases {
|
||||
if strings.TrimSpace(alias) == "" {
|
||||
return errors.New("aliases cannot be blank")
|
||||
}
|
||||
err = qtx.InsertReleaseAlias(ctx, repository.InsertReleaseAliasParams{
|
||||
Alias: strings.TrimSpace(alias),
|
||||
ReleaseID: id,
|
||||
Source: source,
|
||||
IsPrimary: false,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (d *Psql) DeleteAlbum(ctx context.Context, id int32) error {
|
||||
return d.q.DeleteRelease(ctx, id)
|
||||
}
|
||||
func (d *Psql) DeleteAlbumAlias(ctx context.Context, id int32, alias string) error {
|
||||
return d.q.DeleteReleaseAlias(ctx, repository.DeleteReleaseAliasParams{
|
||||
ReleaseID: id,
|
||||
Alias: alias,
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Psql) GetAllAlbumAliases(ctx context.Context, id int32) ([]models.Alias, error) {
|
||||
rows, err := d.q.GetAllReleaseAliases(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aliases := make([]models.Alias, len(rows))
|
||||
for i, row := range rows {
|
||||
aliases[i] = models.Alias{
|
||||
ID: id,
|
||||
Alias: row.Alias,
|
||||
Source: row.Source,
|
||||
Primary: row.IsPrimary,
|
||||
}
|
||||
}
|
||||
return aliases, nil
|
||||
}
|
||||
|
||||
func (d *Psql) SetPrimaryAlbumAlias(ctx context.Context, id int32, alias string) error {
|
||||
l := logger.FromContext(ctx)
|
||||
if id == 0 {
|
||||
return errors.New("artist id not specified")
|
||||
}
|
||||
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to begin transaction")
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
qtx := d.q.WithTx(tx)
|
||||
// get all aliases
|
||||
aliases, err := qtx.GetAllReleaseAliases(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
primary := ""
|
||||
exists := false
|
||||
for _, v := range aliases {
|
||||
if v.Alias == alias {
|
||||
exists = true
|
||||
}
|
||||
if v.IsPrimary {
|
||||
primary = v.Alias
|
||||
}
|
||||
}
|
||||
if primary == alias {
|
||||
// no-op rename
|
||||
return nil
|
||||
}
|
||||
if !exists {
|
||||
return errors.New("alias does not exist")
|
||||
}
|
||||
err = qtx.SetReleaseAliasPrimaryStatus(ctx, repository.SetReleaseAliasPrimaryStatusParams{
|
||||
ReleaseID: id,
|
||||
Alias: alias,
|
||||
IsPrimary: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = qtx.SetReleaseAliasPrimaryStatus(ctx, repository.SetReleaseAliasPrimaryStatusParams{
|
||||
ReleaseID: id,
|
||||
Alias: primary,
|
||||
IsPrimary: false,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
319
internal/db/psql/album_test.go
Normal file
319
internal/db/psql/album_test.go
Normal file
|
|
@ -0,0 +1,319 @@
|
|||
package psql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/gabehf/koito/internal/catalog"
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func truncateTestData(t *testing.T) {
|
||||
err := store.Exec(context.Background(),
|
||||
`TRUNCATE
|
||||
artists,
|
||||
artist_aliases,
|
||||
tracks,
|
||||
artist_tracks,
|
||||
releases,
|
||||
artist_releases,
|
||||
release_aliases,
|
||||
listens
|
||||
RESTART IDENTITY CASCADE`)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func testDataForRelease(t *testing.T) {
|
||||
truncateTestData(t)
|
||||
err := store.Exec(context.Background(),
|
||||
`INSERT INTO artists (musicbrainz_id)
|
||||
VALUES ('00000000-0000-0000-0000-000000000001')`)
|
||||
require.NoError(t, err)
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
|
||||
VALUES (1, 'ATARASHII GAKKO!', 'MusicBrainz', true)`)
|
||||
require.NoError(t, err)
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artists (musicbrainz_id)
|
||||
VALUES ('00000000-0000-0000-0000-000000000002')`)
|
||||
require.NoError(t, err)
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
|
||||
VALUES (2, 'Masayuki Suzuki', 'MusicBrainz', true)`)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestGetAlbum(t *testing.T) {
|
||||
testDataForRelease(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert test data
|
||||
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
|
||||
Title: "Test Release Group",
|
||||
ArtistIDs: []int32{1},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test GetAlbum by ID
|
||||
result, err := store.GetAlbum(ctx, db.GetAlbumOpts{ID: rg.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, rg.ID, result.ID)
|
||||
assert.Equal(t, "Test Release Group", result.Title)
|
||||
|
||||
// Test GetAlbum with insufficient information
|
||||
_, err = store.GetAlbum(ctx, db.GetAlbumOpts{})
|
||||
assert.Error(t, err)
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
|
||||
func TestSaveAlbum(t *testing.T) {
|
||||
testDataForRelease(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Save release group with artist IDs
|
||||
artistIDs := []int32{1, 2}
|
||||
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
|
||||
Title: "New Release Group",
|
||||
ArtistIDs: artistIDs,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify release group was saved
|
||||
assert.Equal(t, "New Release Group", rg.Title)
|
||||
|
||||
// Verify release was created for release group
|
||||
exists, err := store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM releases_with_title
|
||||
WHERE title = $1 AND id = $2
|
||||
)`, "New Release Group", rg.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "expected release to exist")
|
||||
|
||||
// Verify artist associations were created for release group
|
||||
for _, aid := range artistIDs {
|
||||
exists, err := store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM artist_releases
|
||||
WHERE artist_id = $1 AND release_id = $2
|
||||
)`, aid, rg.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "expected artist association to exist")
|
||||
}
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
|
||||
func TestUpdateAlbum(t *testing.T) {
|
||||
testDataForRelease(t)
|
||||
ctx := context.Background()
|
||||
|
||||
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
|
||||
Title: "Old Title",
|
||||
ArtistIDs: []int32{1},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
newMbzID := uuid.New()
|
||||
imgid := uuid.New()
|
||||
err = store.UpdateAlbum(ctx, db.UpdateAlbumOpts{
|
||||
ID: rg.ID,
|
||||
MusicBrainzID: newMbzID,
|
||||
Image: imgid,
|
||||
ImageSrc: catalog.ImageSourceUserUpload,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := store.GetAlbum(ctx, db.GetAlbumOpts{ID: rg.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, newMbzID, *result.MbzID)
|
||||
assert.Equal(t, imgid, *result.Image)
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
func TestAddArtistsToAlbum(t *testing.T) {
|
||||
testDataForRelease(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert test album
|
||||
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
|
||||
Title: "Test Album",
|
||||
ArtistIDs: []int32{1},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add additional artists to the album
|
||||
err = store.AddArtistsToAlbum(ctx, db.AddArtistsToAlbumOpts{
|
||||
AlbumID: rg.ID,
|
||||
ArtistIDs: []int32{2},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify artist associations were created
|
||||
exists, err := store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM artist_releases
|
||||
WHERE artist_id = $1 AND release_id = $2
|
||||
)`, 2, rg.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "expected artist association to exist")
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
func TestSaveAlbumAliases(t *testing.T) {
|
||||
testDataForRelease(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert test album
|
||||
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
|
||||
Title: "Test Album",
|
||||
ArtistIDs: []int32{1},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save aliases for the album
|
||||
aliases := []string{"Alias 1", "Alias 2"}
|
||||
err = store.SaveAlbumAliases(ctx, rg.ID, aliases, "TestSource")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify aliases were saved
|
||||
for _, alias := range aliases {
|
||||
exists, err := store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM release_aliases
|
||||
WHERE release_id = $1 AND alias = $2
|
||||
)`, rg.ID, alias)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "expected alias to exist")
|
||||
}
|
||||
|
||||
err = store.SetPrimaryAlbumAlias(ctx, 1, "Alias 1")
|
||||
require.NoError(t, err)
|
||||
album, err := store.GetAlbum(ctx, db.GetAlbumOpts{ID: rg.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Alias 1", album.Title)
|
||||
|
||||
err = store.SetPrimaryAlbumAlias(ctx, 1, "Fake Alias")
|
||||
require.Error(t, err)
|
||||
|
||||
store.SetPrimaryAlbumAlias(ctx, 1, "Album One")
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
func TestDeleteAlbum(t *testing.T) {
|
||||
testDataForRelease(t)
|
||||
ctx := context.Background()
|
||||
|
||||
testDataForTopItems(t)
|
||||
|
||||
// Delete the album
|
||||
err := store.DeleteAlbum(ctx, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify album was deleted
|
||||
exists, err := store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM releases
|
||||
WHERE id = $1
|
||||
)`, 1)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "expected album to be deleted")
|
||||
|
||||
// Verify album's track was deleted
|
||||
exists, err = store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM tracks
|
||||
WHERE id = $1
|
||||
)`, 1)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "expected album's tracks to be deleted")
|
||||
|
||||
// Verify album's listens was deleted
|
||||
exists, err = store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM listens
|
||||
WHERE track_id = $1
|
||||
)`, 1)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "expected album's listens to be deleted")
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
func TestDeleteAlbumAlias(t *testing.T) {
|
||||
testDataForRelease(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert test album
|
||||
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
|
||||
Title: "Test Album",
|
||||
ArtistIDs: []int32{1},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save aliases for the album
|
||||
aliases := []string{"Alias 1", "Alias 2"}
|
||||
err = store.SaveAlbumAliases(ctx, rg.ID, aliases, "TestSource")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete one alias
|
||||
err = store.DeleteAlbumAlias(ctx, rg.ID, "Alias 1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify alias was deleted
|
||||
exists, err := store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM release_aliases
|
||||
WHERE release_id = $1 AND alias = $2
|
||||
)`, rg.ID, "Alias 1")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "expected alias to be deleted")
|
||||
|
||||
// Verify other alias still exists
|
||||
exists, err = store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM release_aliases
|
||||
WHERE release_id = $1 AND alias = $2
|
||||
)`, rg.ID, "Alias 2")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "expected alias to still exist")
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
func TestGetAllAlbumAliases(t *testing.T) {
|
||||
testDataForRelease(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert test album
|
||||
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
|
||||
Title: "Test Album",
|
||||
ArtistIDs: []int32{1},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save aliases for the album
|
||||
aliases := []string{"Alias 1", "Alias 2"}
|
||||
err = store.SaveAlbumAliases(ctx, rg.ID, aliases, "TestSource")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve all aliases
|
||||
result, err := store.GetAllAlbumAliases(ctx, rg.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, len(aliases)+1) // new + canonical
|
||||
|
||||
for _, alias := range aliases {
|
||||
found := false
|
||||
for _, res := range result {
|
||||
if res.Alias == alias {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected alias to be retrieved")
|
||||
}
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
309
internal/db/psql/artist.go
Normal file
309
internal/db/psql/artist.go
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/gabehf/koito/internal/logger"
|
||||
"github.com/gabehf/koito/internal/models"
|
||||
"github.com/gabehf/koito/internal/repository"
|
||||
"github.com/gabehf/koito/internal/utils"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
func (d *Psql) GetArtist(ctx context.Context, opts db.GetArtistOpts) (*models.Artist, error) {
|
||||
l := logger.FromContext(ctx)
|
||||
if opts.ID != 0 {
|
||||
l.Debug().Msgf("Fetching artist from DB with id %d", opts.ID)
|
||||
row, err := d.q.GetArtist(ctx, opts.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
count, err := d.q.CountListensFromArtist(ctx, repository.CountListensFromArtistParams{
|
||||
ListenedAt: time.Unix(0, 0),
|
||||
ListenedAt_2: time.Now(),
|
||||
ArtistID: row.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &models.Artist{
|
||||
ID: row.ID,
|
||||
MbzID: row.MusicBrainzID,
|
||||
Name: row.Name,
|
||||
Aliases: row.Aliases,
|
||||
Image: row.Image,
|
||||
ListenCount: count,
|
||||
}, nil
|
||||
} else if opts.MusicBrainzID != uuid.Nil {
|
||||
l.Debug().Msgf("Fetching artist from DB with MusicBrainz ID %s", opts.MusicBrainzID)
|
||||
row, err := d.q.GetArtistByMbzID(ctx, &opts.MusicBrainzID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
count, err := d.q.CountListensFromArtist(ctx, repository.CountListensFromArtistParams{
|
||||
ListenedAt: time.Unix(0, 0),
|
||||
ListenedAt_2: time.Now(),
|
||||
ArtistID: row.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &models.Artist{
|
||||
ID: row.ID,
|
||||
MbzID: row.MusicBrainzID,
|
||||
Name: row.Name,
|
||||
Aliases: row.Aliases,
|
||||
Image: row.Image,
|
||||
ListenCount: count,
|
||||
}, nil
|
||||
} else if opts.Name != "" {
|
||||
l.Debug().Msgf("Fetching artist from DB with name '%s'", opts.Name)
|
||||
row, err := d.q.GetArtistByName(ctx, opts.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
count, err := d.q.CountListensFromArtist(ctx, repository.CountListensFromArtistParams{
|
||||
ListenedAt: time.Unix(0, 0),
|
||||
ListenedAt_2: time.Now(),
|
||||
ArtistID: row.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &models.Artist{
|
||||
ID: row.ID,
|
||||
MbzID: row.MusicBrainzID,
|
||||
Name: row.Name,
|
||||
Aliases: row.Aliases,
|
||||
Image: row.Image,
|
||||
ListenCount: count,
|
||||
}, nil
|
||||
} else {
|
||||
return nil, errors.New("insufficient information to get artist")
|
||||
}
|
||||
}
|
||||
|
||||
// Inserts all unique aliases into the DB with specified source
|
||||
func (d *Psql) SaveArtistAliases(ctx context.Context, id int32, aliases []string, source string) error {
|
||||
l := logger.FromContext(ctx)
|
||||
if id == 0 {
|
||||
return errors.New("artist id not specified")
|
||||
}
|
||||
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to begin transaction")
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
qtx := d.q.WithTx(tx)
|
||||
existing, err := qtx.GetAllArtistAliases(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, v := range existing {
|
||||
aliases = append(aliases, v.Alias)
|
||||
}
|
||||
utils.Unique(&aliases)
|
||||
for _, alias := range aliases {
|
||||
if strings.TrimSpace(alias) == "" {
|
||||
return errors.New("aliases cannot be blank")
|
||||
}
|
||||
err = qtx.InsertArtistAlias(ctx, repository.InsertArtistAliasParams{
|
||||
Alias: strings.TrimSpace(alias),
|
||||
ArtistID: id,
|
||||
Source: source,
|
||||
IsPrimary: false,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (d *Psql) DeleteArtist(ctx context.Context, id int32) error {
|
||||
return d.q.DeleteArtist(ctx, id)
|
||||
}
|
||||
|
||||
// Equivalent to Psql.SaveArtist, then Psql.SaveMbzAliases
|
||||
func (d *Psql) SaveArtist(ctx context.Context, opts db.SaveArtistOpts) (*models.Artist, error) {
|
||||
l := logger.FromContext(ctx)
|
||||
var insertMbzID *uuid.UUID
|
||||
var insertImage *uuid.UUID
|
||||
if opts.MusicBrainzID != uuid.Nil {
|
||||
insertMbzID = &opts.MusicBrainzID
|
||||
}
|
||||
if opts.Image != uuid.Nil {
|
||||
insertImage = &opts.Image
|
||||
}
|
||||
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to begin transaction")
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
qtx := d.q.WithTx(tx)
|
||||
opts.Name = strings.TrimSpace(opts.Name)
|
||||
if opts.Name == "" {
|
||||
return nil, errors.New("name must not be blank")
|
||||
}
|
||||
l.Debug().Msgf("Inserting artist '%s' into DB", opts.Name)
|
||||
a, err := qtx.InsertArtist(ctx, repository.InsertArtistParams{
|
||||
MusicBrainzID: insertMbzID,
|
||||
Image: insertImage,
|
||||
ImageSource: pgtype.Text{String: opts.ImageSrc, Valid: opts.ImageSrc != ""},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.Debug().Msgf("Inserting canonical alias '%s' into DB for artist with id %d", opts.Name, a.ID)
|
||||
err = qtx.InsertArtistAlias(ctx, repository.InsertArtistAliasParams{
|
||||
ArtistID: a.ID,
|
||||
Alias: opts.Name,
|
||||
Source: "Canonical",
|
||||
IsPrimary: true,
|
||||
})
|
||||
if err != nil {
|
||||
l.Error().Err(err).Msgf("Error inserting canonical alias for artist '%s'", opts.Name)
|
||||
return nil, err
|
||||
}
|
||||
err = tx.Commit(ctx)
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to commit insert artist transaction")
|
||||
return nil, err
|
||||
}
|
||||
artist := &models.Artist{
|
||||
ID: a.ID,
|
||||
Name: opts.Name,
|
||||
Image: a.Image,
|
||||
MbzID: a.MusicBrainzID,
|
||||
Aliases: []string{opts.Name},
|
||||
}
|
||||
if len(opts.Aliases) > 0 {
|
||||
l.Debug().Msgf("Inserting aliases '%v' into DB for artist '%s'", opts.Aliases, opts.Name)
|
||||
err = d.SaveArtistAliases(ctx, a.ID, opts.Aliases, "MusicBrainz")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
artist.Aliases = opts.Aliases
|
||||
}
|
||||
return artist, nil
|
||||
}
|
||||
|
||||
func (d *Psql) UpdateArtist(ctx context.Context, opts db.UpdateArtistOpts) error {
|
||||
l := logger.FromContext(ctx)
|
||||
if opts.ID == 0 {
|
||||
return errors.New("artist id not specified")
|
||||
}
|
||||
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to begin transaction")
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
qtx := d.q.WithTx(tx)
|
||||
if opts.MusicBrainzID != uuid.Nil {
|
||||
l.Debug().Msgf("Updating artist with id %d with MusicBrainz ID %s", opts.ID, opts.MusicBrainzID)
|
||||
err := qtx.UpdateArtistMbzID(ctx, repository.UpdateArtistMbzIDParams{
|
||||
ID: opts.ID,
|
||||
MusicBrainzID: &opts.MusicBrainzID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if opts.Image != uuid.Nil {
|
||||
l.Debug().Msgf("Updating artist with id %d with image %s", opts.ID, opts.Image)
|
||||
err = qtx.UpdateArtistImage(ctx, repository.UpdateArtistImageParams{
|
||||
ID: opts.ID,
|
||||
Image: &opts.Image,
|
||||
ImageSource: pgtype.Text{String: opts.ImageSrc, Valid: opts.ImageSrc != ""},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (d *Psql) DeleteArtistAlias(ctx context.Context, id int32, alias string) error {
|
||||
return d.q.DeleteArtistAlias(ctx, repository.DeleteArtistAliasParams{
|
||||
ArtistID: id,
|
||||
Alias: alias,
|
||||
})
|
||||
}
|
||||
func (d *Psql) GetAllArtistAliases(ctx context.Context, id int32) ([]models.Alias, error) {
|
||||
rows, err := d.q.GetAllArtistAliases(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aliases := make([]models.Alias, len(rows))
|
||||
for i, row := range rows {
|
||||
aliases[i] = models.Alias{
|
||||
ID: id,
|
||||
Alias: row.Alias,
|
||||
Source: row.Source,
|
||||
Primary: row.IsPrimary,
|
||||
}
|
||||
}
|
||||
return aliases, nil
|
||||
}
|
||||
|
||||
func (d *Psql) SetPrimaryArtistAlias(ctx context.Context, id int32, alias string) error {
|
||||
l := logger.FromContext(ctx)
|
||||
if id == 0 {
|
||||
return errors.New("artist id not specified")
|
||||
}
|
||||
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to begin transaction")
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
qtx := d.q.WithTx(tx)
|
||||
// get all aliases
|
||||
aliases, err := qtx.GetAllArtistAliases(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
primary := ""
|
||||
exists := false
|
||||
for _, v := range aliases {
|
||||
if v.Alias == alias {
|
||||
exists = true
|
||||
}
|
||||
if v.IsPrimary {
|
||||
primary = v.Alias
|
||||
}
|
||||
}
|
||||
if primary == alias {
|
||||
// no-op rename
|
||||
return nil
|
||||
}
|
||||
if !exists {
|
||||
return errors.New("alias does not exist")
|
||||
}
|
||||
err = qtx.SetArtistAliasPrimaryStatus(ctx, repository.SetArtistAliasPrimaryStatusParams{
|
||||
ArtistID: id,
|
||||
Alias: alias,
|
||||
IsPrimary: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = qtx.SetArtistAliasPrimaryStatus(ctx, repository.SetArtistAliasPrimaryStatusParams{
|
||||
ArtistID: id,
|
||||
Alias: primary,
|
||||
IsPrimary: false,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
247
internal/db/psql/artist_test.go
Normal file
247
internal/db/psql/artist_test.go
Normal file
|
|
@ -0,0 +1,247 @@
|
|||
package psql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/gabehf/koito/internal/catalog"
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetArtist(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mbzId := uuid.MustParse("00000000-0000-0000-0000-000000000001")
|
||||
// Insert test data
|
||||
artist, err := store.SaveArtist(ctx, db.SaveArtistOpts{
|
||||
Name: "Test Artist",
|
||||
MusicBrainzID: mbzId,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test GetArtist by ID
|
||||
result, err := store.GetArtist(ctx, db.GetArtistOpts{ID: artist.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, artist.ID, result.ID)
|
||||
assert.Equal(t, "Test Artist", result.Name)
|
||||
|
||||
// Test GetArtist by Name
|
||||
result, err = store.GetArtist(ctx, db.GetArtistOpts{Name: artist.Name})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, artist.ID, result.ID)
|
||||
|
||||
// Test GetArtist by MusicBrainzID
|
||||
result, err = store.GetArtist(ctx, db.GetArtistOpts{MusicBrainzID: mbzId})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, artist.ID, result.ID)
|
||||
|
||||
// Test GetArtist with insufficient information
|
||||
_, err = store.GetArtist(ctx, db.GetArtistOpts{})
|
||||
assert.Error(t, err)
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
|
||||
func TestSaveAliases(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert test artist
|
||||
artist, err := store.SaveArtist(ctx, db.SaveArtistOpts{
|
||||
Name: "Alias Artist",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save aliases
|
||||
aliases := []string{"Alias1", "Alias2"}
|
||||
err = store.SaveArtistAliases(ctx, artist.ID, aliases, "MusicBrainz")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify aliases were saved
|
||||
for _, alias := range aliases {
|
||||
exists, err := store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM artist_aliases
|
||||
WHERE artist_id = $1 AND alias = $2
|
||||
)`, artist.ID, alias)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "expected alias to exist")
|
||||
}
|
||||
|
||||
err = store.SetPrimaryArtistAlias(ctx, 1, "Alias1")
|
||||
require.NoError(t, err)
|
||||
artist, err = store.GetArtist(ctx, db.GetArtistOpts{ID: artist.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Alias1", artist.Name)
|
||||
|
||||
err = store.SetPrimaryArtistAlias(ctx, 1, "Fake Alias")
|
||||
require.Error(t, err)
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
|
||||
func TestSaveArtist(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Save artist with aliases
|
||||
aliases := []string{"Alias1", "Alias2"}
|
||||
artist, err := store.SaveArtist(ctx, db.SaveArtistOpts{
|
||||
Name: "New Artist",
|
||||
Aliases: aliases,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify artist was saved
|
||||
assert.Equal(t, "New Artist", artist.Name)
|
||||
|
||||
// Verify aliases were saved
|
||||
for _, alias := range slices.Concat(aliases, []string{"New Artist"}) {
|
||||
exists, err := store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM artist_aliases
|
||||
WHERE artist_id = $1 AND alias = $2
|
||||
)`, artist.ID, alias)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "expected alias '%s' to exist", alias)
|
||||
}
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
|
||||
func TestUpdateArtist(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert test artist
|
||||
artist, err := store.SaveArtist(ctx, db.SaveArtistOpts{
|
||||
Name: "Old Name",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
imgid := uuid.New()
|
||||
err = store.UpdateArtist(ctx, db.UpdateArtistOpts{
|
||||
ID: artist.ID,
|
||||
Image: imgid,
|
||||
ImageSrc: catalog.ImageSourceUserUpload,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := store.GetArtist(ctx, db.GetArtistOpts{ID: artist.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, imgid, *result.Image)
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
func TestGetAllArtistAliases(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert test artist
|
||||
artist, err := store.SaveArtist(ctx, db.SaveArtistOpts{
|
||||
Name: "Alias Artist",
|
||||
Aliases: []string{"Alias1", "Alias2"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve all aliases
|
||||
result, err := store.GetAllArtistAliases(ctx, artist.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 3) // Includes canonical alias
|
||||
|
||||
// Verify aliases were retrieved
|
||||
expectedAliases := []string{"Alias Artist", "Alias1", "Alias2"}
|
||||
for _, alias := range expectedAliases {
|
||||
found := false
|
||||
for _, res := range result {
|
||||
if res.Alias == alias {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected alias '%s' to be retrieved", alias)
|
||||
}
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
func TestDeleteArtistAlias(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert test artist
|
||||
artist, err := store.SaveArtist(ctx, db.SaveArtistOpts{
|
||||
Name: "Alias Artist",
|
||||
Aliases: []string{"Alias1", "Alias2"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete one alias
|
||||
err = store.DeleteArtistAlias(ctx, artist.ID, "Alias1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify alias was deleted
|
||||
exists, err := store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM artist_aliases
|
||||
WHERE artist_id = $1 AND alias = $2
|
||||
)`, artist.ID, "Alias1")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "expected alias to be deleted")
|
||||
|
||||
// Verify other alias still exists
|
||||
exists, err = store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM artist_aliases
|
||||
WHERE artist_id = $1 AND alias = $2
|
||||
)`, artist.ID, "Alias2")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "expected alias to still exist")
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
func TestDeleteArtist(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// set up a lot of test data, 4 artists, 4 albums, 4 tracks, 10 listens
|
||||
testDataForTopItems(t)
|
||||
|
||||
// Delete the artist
|
||||
err := store.DeleteArtist(ctx, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify artist was deleted
|
||||
exists, err := store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM artists
|
||||
WHERE id = $1
|
||||
)`, 1)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "expected artist to be deleted")
|
||||
|
||||
// Verify artist's release was deleted
|
||||
exists, err = store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM releases
|
||||
WHERE id = $1
|
||||
)`, 1)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "expected artist's release to be deleted")
|
||||
|
||||
// Verify artist's track was deleted
|
||||
exists, err = store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM tracks
|
||||
WHERE id = $1
|
||||
)`, 1)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "expected artist's tracks to be deleted")
|
||||
|
||||
// Verify artist's listens was deleted
|
||||
exists, err = store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM listens
|
||||
WHERE track_id = $1
|
||||
)`, 1)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "expected artist's listens to be deleted")
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
70
internal/db/psql/counts.go
Normal file
70
internal/db/psql/counts.go
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/gabehf/koito/internal/repository"
|
||||
)
|
||||
|
||||
func (p *Psql) CountListens(ctx context.Context, period db.Period) (int64, error) {
|
||||
t2 := time.Now()
|
||||
t1 := db.StartTimeFromPeriod(period)
|
||||
count, err := p.q.CountListens(ctx, repository.CountListensParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
func (p *Psql) CountTracks(ctx context.Context, period db.Period) (int64, error) {
|
||||
t2 := time.Now()
|
||||
t1 := db.StartTimeFromPeriod(period)
|
||||
count, err := p.q.CountTopTracks(ctx, repository.CountTopTracksParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
func (p *Psql) CountAlbums(ctx context.Context, period db.Period) (int64, error) {
|
||||
t2 := time.Now()
|
||||
t1 := db.StartTimeFromPeriod(period)
|
||||
count, err := p.q.CountTopReleases(ctx, repository.CountTopReleasesParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
func (p *Psql) CountArtists(ctx context.Context, period db.Period) (int64, error) {
|
||||
t2 := time.Now()
|
||||
t1 := db.StartTimeFromPeriod(period)
|
||||
count, err := p.q.CountTopArtists(ctx, repository.CountTopArtistsParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
func (p *Psql) CountTimeListened(ctx context.Context, period db.Period) (int64, error) {
|
||||
t2 := time.Now()
|
||||
t1 := db.StartTimeFromPeriod(period)
|
||||
count, err := p.q.CountTimeListened(ctx, repository.CountTimeListenedParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
76
internal/db/psql/counts_test.go
Normal file
76
internal/db/psql/counts_test.go
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
package psql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCountListens(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
testDataForTopItems(t)
|
||||
|
||||
// Test CountListens
|
||||
period := db.PeriodWeek
|
||||
count, err := store.CountListens(ctx, period)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count, "expected listens count to match inserted data")
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
|
||||
func TestCountTracks(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
testDataForTopItems(t)
|
||||
|
||||
// Test CountTracks
|
||||
period := db.PeriodMonth
|
||||
count, err := store.CountTracks(ctx, period)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(2), count, "expected tracks count to match inserted data")
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
|
||||
func TestCountAlbums(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
testDataForTopItems(t)
|
||||
|
||||
// Test CountAlbums
|
||||
period := db.PeriodYear
|
||||
count, err := store.CountAlbums(ctx, period)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(3), count, "expected albums count to match inserted data")
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
|
||||
func TestCountArtists(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
testDataForTopItems(t)
|
||||
|
||||
// Test CountArtists
|
||||
period := db.PeriodAllTime
|
||||
count, err := store.CountArtists(ctx, period)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(4), count, "expected artists count to match inserted data")
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
|
||||
func TestCountTimeListened(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
testDataForTopItems(t)
|
||||
|
||||
// Test CountTimeListened
|
||||
period := db.PeriodMonth
|
||||
count, err := store.CountTimeListened(ctx, period)
|
||||
require.NoError(t, err)
|
||||
// 3 listens in past month, each 100 seconds
|
||||
assert.Equal(t, int64(300), count, "expected total time listened to match inserted data")
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
74
internal/db/psql/images.go
Normal file
74
internal/db/psql/images.go
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/gabehf/koito/internal/logger"
|
||||
"github.com/gabehf/koito/internal/models"
|
||||
"github.com/gabehf/koito/internal/repository"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
func (d *Psql) ImageHasAssociation(ctx context.Context, image uuid.UUID) (bool, error) {
|
||||
_, err := d.q.GetReleaseByImageID(ctx, &image)
|
||||
if err == nil {
|
||||
return true, err
|
||||
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return false, err
|
||||
}
|
||||
_, err = d.q.GetArtistByImage(ctx, &image)
|
||||
if err == nil {
|
||||
return true, err
|
||||
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return false, err
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (d *Psql) GetImageSource(ctx context.Context, image uuid.UUID) (string, error) {
|
||||
r, err := d.q.GetReleaseByImageID(ctx, &image)
|
||||
if err == nil {
|
||||
return r.ImageSource.String, err
|
||||
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", err
|
||||
}
|
||||
rr, err := d.q.GetArtistByImage(ctx, &image)
|
||||
if err == nil {
|
||||
return rr.ImageSource.String, err
|
||||
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", err
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (d *Psql) AlbumsWithoutImages(ctx context.Context, from int32) ([]*models.Album, error) {
|
||||
l := logger.FromContext(ctx)
|
||||
rows, err := d.q.GetReleasesWithoutImages(ctx, repository.GetReleasesWithoutImagesParams{
|
||||
Limit: 20,
|
||||
ID: from,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
albums := make([]*models.Album, len(rows))
|
||||
for i, row := range rows {
|
||||
artists := make([]models.SimpleArtist, 0)
|
||||
err = json.Unmarshal(row.Artists, &artists)
|
||||
if err != nil {
|
||||
l.Err(err).Msgf("Error unmarshalling artists for release group with id %d", row.ID)
|
||||
artists = nil
|
||||
}
|
||||
albums[i] = &models.Album{
|
||||
ID: row.ID,
|
||||
Image: row.Image,
|
||||
Title: row.Title,
|
||||
MbzID: row.MusicBrainzID,
|
||||
VariousArtists: row.VariousArtists,
|
||||
Artists: artists,
|
||||
}
|
||||
}
|
||||
return albums, nil
|
||||
}
|
||||
106
internal/db/psql/images_test.go
Normal file
106
internal/db/psql/images_test.go
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
package psql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/gabehf/koito/internal/catalog"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setupTestDataForImages(t *testing.T) {
|
||||
truncateTestData(t)
|
||||
|
||||
// Insert artists
|
||||
err := store.Exec(context.Background(),
|
||||
`INSERT INTO artists (musicbrainz_id, image, image_source)
|
||||
VALUES ('00000000-0000-0000-0000-000000000001', '11111111-1111-1111-1111-111111111111', 'User Upload'),
|
||||
('00000000-0000-0000-0000-000000000002', NULL, NULL)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert artist aliases
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
|
||||
VALUES (1, 'Artist One', 'Testing', true),
|
||||
(2, 'Artist Two', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert albums
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO releases (musicbrainz_id, image, image_source)
|
||||
VALUES ('22222222-2222-2222-2222-222222222222', '33333333-3333-3333-3333-333333333333', 'Automatic'),
|
||||
('44444444-4444-4444-4444-444444444444', NULL, NULL)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert release aliases
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
|
||||
VALUES (1, 'Album One', 'Testing', true),
|
||||
(2, 'Album Two', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Associate albums with artists
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_releases (artist_id, release_id)
|
||||
VALUES (1, 1), (2, 2)`)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestImageHasAssociation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
setupTestDataForImages(t)
|
||||
|
||||
// Test image with association
|
||||
imageID := uuid.MustParse("11111111-1111-1111-1111-111111111111")
|
||||
hasAssociation, err := store.ImageHasAssociation(ctx, imageID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, hasAssociation, "expected image to have an association")
|
||||
|
||||
// Test image without association
|
||||
imageID = uuid.MustParse("55555555-5555-5555-5555-555555555555")
|
||||
hasAssociation, err = store.ImageHasAssociation(ctx, imageID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, hasAssociation, "expected image to have no association")
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
|
||||
func TestGetImageSource(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
setupTestDataForImages(t)
|
||||
|
||||
// Test image source for an album
|
||||
imageID := uuid.MustParse("33333333-3333-3333-3333-333333333333")
|
||||
source, err := store.GetImageSource(ctx, imageID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Automatic", source, "expected image source to match")
|
||||
|
||||
// Test image source for an artist
|
||||
imageID = uuid.MustParse("11111111-1111-1111-1111-111111111111")
|
||||
source, err = store.GetImageSource(ctx, imageID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, catalog.ImageSourceUserUpload, source, "expected image source to match")
|
||||
|
||||
// Test image source for a non-existent image
|
||||
imageID = uuid.MustParse("55555555-5555-5555-5555-555555555555")
|
||||
source, err = store.GetImageSource(ctx, imageID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "", source, "expected no image source for non-existent image")
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
|
||||
func TestAlbumsWithoutImages(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
setupTestDataForImages(t)
|
||||
|
||||
// Test albums without images
|
||||
albums, err := store.AlbumsWithoutImages(ctx, 0)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, albums, 1, "expected one album without an image")
|
||||
assert.Equal(t, "Album Two", albums[0].Title, "expected album title to match")
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
218
internal/db/psql/listen.go
Normal file
218
internal/db/psql/listen.go
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/gabehf/koito/internal/logger"
|
||||
"github.com/gabehf/koito/internal/models"
|
||||
"github.com/gabehf/koito/internal/repository"
|
||||
"github.com/gabehf/koito/internal/utils"
|
||||
)
|
||||
|
||||
func (d *Psql) GetListensPaginated(ctx context.Context, opts db.GetItemsOpts) (*db.PaginatedResponse[*models.Listen], error) {
|
||||
l := logger.FromContext(ctx)
|
||||
offset := (opts.Page - 1) * opts.Limit
|
||||
t1, t2, err := utils.DateRange(opts.Week, opts.Month, opts.Year)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if opts.Month == 0 && opts.Year == 0 {
|
||||
// use period, not date range
|
||||
t2 = time.Now()
|
||||
t1 = db.StartTimeFromPeriod(opts.Period)
|
||||
}
|
||||
if opts.Limit == 0 {
|
||||
opts.Limit = DefaultItemsPerPage
|
||||
}
|
||||
var listens []*models.Listen
|
||||
var count int64
|
||||
if opts.TrackID > 0 {
|
||||
l.Debug().Msgf("Fetching %d listens with period %s on page %d from range %v to %v",
|
||||
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
|
||||
rows, err := d.q.GetLastListensFromTrackPaginated(ctx, repository.GetLastListensFromTrackPaginatedParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
Limit: int32(opts.Limit),
|
||||
Offset: int32(offset),
|
||||
ID: int32(opts.TrackID),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listens = make([]*models.Listen, len(rows))
|
||||
for i, row := range rows {
|
||||
t := &models.Listen{
|
||||
Track: models.Track{
|
||||
Title: row.TrackTitle,
|
||||
ID: row.TrackID,
|
||||
},
|
||||
Time: row.ListenedAt,
|
||||
}
|
||||
err = json.Unmarshal(row.Artists, &t.Track.Artists)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listens[i] = t
|
||||
}
|
||||
count, err = d.q.CountListensFromTrack(ctx, repository.CountListensFromTrackParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
TrackID: int32(opts.TrackID),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if opts.AlbumID > 0 {
|
||||
l.Debug().Msgf("Fetching %d listens with period %s on page %d from range %v to %v",
|
||||
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
|
||||
rows, err := d.q.GetLastListensFromReleasePaginated(ctx, repository.GetLastListensFromReleasePaginatedParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
Limit: int32(opts.Limit),
|
||||
Offset: int32(offset),
|
||||
ReleaseID: int32(opts.AlbumID),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listens = make([]*models.Listen, len(rows))
|
||||
for i, row := range rows {
|
||||
t := &models.Listen{
|
||||
Track: models.Track{
|
||||
Title: row.TrackTitle,
|
||||
ID: row.TrackID,
|
||||
},
|
||||
Time: row.ListenedAt,
|
||||
}
|
||||
err = json.Unmarshal(row.Artists, &t.Track.Artists)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listens[i] = t
|
||||
}
|
||||
count, err = d.q.CountListensFromRelease(ctx, repository.CountListensFromReleaseParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
ReleaseID: int32(opts.AlbumID),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if opts.ArtistID > 0 {
|
||||
l.Debug().Msgf("Fetching %d listens with period %s on page %d from range %v to %v",
|
||||
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
|
||||
rows, err := d.q.GetLastListensFromArtistPaginated(ctx, repository.GetLastListensFromArtistPaginatedParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
Limit: int32(opts.Limit),
|
||||
Offset: int32(offset),
|
||||
ArtistID: int32(opts.ArtistID),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listens = make([]*models.Listen, len(rows))
|
||||
for i, row := range rows {
|
||||
t := &models.Listen{
|
||||
Track: models.Track{
|
||||
Title: row.TrackTitle,
|
||||
ID: row.TrackID,
|
||||
},
|
||||
Time: row.ListenedAt,
|
||||
}
|
||||
err = json.Unmarshal(row.Artists, &t.Track.Artists)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listens[i] = t
|
||||
}
|
||||
count, err = d.q.CountListensFromArtist(ctx, repository.CountListensFromArtistParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
ArtistID: int32(opts.ArtistID),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
l.Debug().Msgf("Fetching %d listens with period %s on page %d from range %v to %v",
|
||||
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
|
||||
rows, err := d.q.GetLastListensPaginated(ctx, repository.GetLastListensPaginatedParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
Limit: int32(opts.Limit),
|
||||
Offset: int32(offset),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listens = make([]*models.Listen, len(rows))
|
||||
for i, row := range rows {
|
||||
t := &models.Listen{
|
||||
Track: models.Track{
|
||||
Title: row.TrackTitle,
|
||||
ID: row.TrackID,
|
||||
},
|
||||
Time: row.ListenedAt,
|
||||
}
|
||||
err = json.Unmarshal(row.Artists, &t.Track.Artists)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listens[i] = t
|
||||
}
|
||||
count, err = d.q.CountListens(ctx, repository.CountListensParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.Debug().Msgf("Database responded with %d tracks out of a total %d", len(rows), count)
|
||||
}
|
||||
|
||||
return &db.PaginatedResponse[*models.Listen]{
|
||||
Items: listens,
|
||||
TotalCount: count,
|
||||
ItemsPerPage: int32(opts.Limit),
|
||||
HasNextPage: int64(offset+len(listens)) < count,
|
||||
CurrentPage: int32(opts.Page),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Psql) SaveListen(ctx context.Context, opts db.SaveListenOpts) error {
|
||||
l := logger.FromContext(ctx)
|
||||
if opts.TrackID == 0 {
|
||||
return errors.New("required parameter TrackID missing")
|
||||
}
|
||||
if opts.Time.IsZero() {
|
||||
opts.Time = time.Now()
|
||||
}
|
||||
var client *string
|
||||
if opts.Client != "" {
|
||||
client = &opts.Client
|
||||
}
|
||||
l.Debug().Msgf("Inserting listen for track with id %d at time %v into DB", opts.TrackID, opts.Time)
|
||||
return d.q.InsertListen(ctx, repository.InsertListenParams{
|
||||
TrackID: opts.TrackID,
|
||||
ListenedAt: opts.Time,
|
||||
UserID: opts.UserID,
|
||||
Client: client,
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Psql) DeleteListen(ctx context.Context, trackId int32, listenedAt time.Time) error {
|
||||
l := logger.FromContext(ctx)
|
||||
if trackId == 0 {
|
||||
return errors.New("required parameter 'trackId' missing")
|
||||
}
|
||||
l.Debug().Msgf("Deleting listen from track %d at time %s from DB", trackId, listenedAt)
|
||||
return d.q.DeleteListen(ctx, repository.DeleteListenParams{
|
||||
TrackID: trackId,
|
||||
ListenedAt: listenedAt,
|
||||
})
|
||||
}
|
||||
109
internal/db/psql/listen_activity.go
Normal file
109
internal/db/psql/listen_activity.go
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/gabehf/koito/internal/logger"
|
||||
"github.com/gabehf/koito/internal/repository"
|
||||
)
|
||||
|
||||
func (d *Psql) GetListenActivity(ctx context.Context, opts db.ListenActivityOpts) ([]db.ListenActivityItem, error) {
|
||||
l := logger.FromContext(ctx)
|
||||
if opts.Month != 0 && opts.Year == 0 {
|
||||
return nil, errors.New("year must be specified with month")
|
||||
}
|
||||
// Default to range = 12 if not set
|
||||
if opts.Range == 0 {
|
||||
opts.Range = db.DefaultRange
|
||||
}
|
||||
t1, t2 := db.ListenActivityOptsToTimes(opts)
|
||||
var listenActivity []db.ListenActivityItem
|
||||
if opts.AlbumID > 0 {
|
||||
l.Debug().Msgf("Fetching listen activity for %d %s(s) from %v to %v for release group %d",
|
||||
opts.Range, opts.Step, t1.Format("Jan 02, 2006 15:04:05"), t2.Format("Jan 02, 2006 15:04:05"), opts.AlbumID)
|
||||
rows, err := d.q.ListenActivityForRelease(ctx, repository.ListenActivityForReleaseParams{
|
||||
Column1: t1,
|
||||
Column2: t2,
|
||||
Column3: stepToInterval(opts.Step),
|
||||
ReleaseID: opts.AlbumID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listenActivity = make([]db.ListenActivityItem, len(rows))
|
||||
for i, row := range rows {
|
||||
t := db.ListenActivityItem{
|
||||
Start: row.BucketStart,
|
||||
Listens: row.ListenCount,
|
||||
}
|
||||
listenActivity[i] = t
|
||||
}
|
||||
l.Debug().Msgf("Database responded with %d steps", len(rows))
|
||||
} else if opts.ArtistID > 0 {
|
||||
l.Debug().Msgf("Fetching listen activity for %d %s(s) from %v to %v for artist %d",
|
||||
opts.Range, opts.Step, t1.Format("Jan 02, 2006 15:04:05"), t2.Format("Jan 02, 2006 15:04:05"), opts.ArtistID)
|
||||
rows, err := d.q.ListenActivityForArtist(ctx, repository.ListenActivityForArtistParams{
|
||||
Column1: t1,
|
||||
Column2: t2,
|
||||
Column3: stepToInterval(opts.Step),
|
||||
ArtistID: opts.ArtistID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listenActivity = make([]db.ListenActivityItem, len(rows))
|
||||
for i, row := range rows {
|
||||
t := db.ListenActivityItem{
|
||||
Start: row.BucketStart,
|
||||
Listens: row.ListenCount,
|
||||
}
|
||||
listenActivity[i] = t
|
||||
}
|
||||
l.Debug().Msgf("Database responded with %d steps", len(rows))
|
||||
} else if opts.TrackID > 0 {
|
||||
l.Debug().Msgf("Fetching listen activity for %d %s(s) from %v to %v for track %d",
|
||||
opts.Range, opts.Step, t1.Format("Jan 02, 2006 15:04:05"), t2.Format("Jan 02, 2006 15:04:05"), opts.TrackID)
|
||||
rows, err := d.q.ListenActivityForTrack(ctx, repository.ListenActivityForTrackParams{
|
||||
Column1: t1,
|
||||
Column2: t2,
|
||||
Column3: stepToInterval(opts.Step),
|
||||
ID: opts.TrackID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listenActivity = make([]db.ListenActivityItem, len(rows))
|
||||
for i, row := range rows {
|
||||
t := db.ListenActivityItem{
|
||||
Start: row.BucketStart,
|
||||
Listens: row.ListenCount,
|
||||
}
|
||||
listenActivity[i] = t
|
||||
}
|
||||
l.Debug().Msgf("Database responded with %d steps", len(rows))
|
||||
} else {
|
||||
l.Debug().Msgf("Fetching listen activity for %d %s(s) from %v to %v",
|
||||
opts.Range, opts.Step, t1.Format("Jan 02, 2006 15:04:05"), t2.Format("Jan 02, 2006 15:04:05"))
|
||||
rows, err := d.q.ListenActivity(ctx, repository.ListenActivityParams{
|
||||
Column1: t1,
|
||||
Column2: t2,
|
||||
Column3: stepToInterval(opts.Step),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listenActivity = make([]db.ListenActivityItem, len(rows))
|
||||
for i, row := range rows {
|
||||
t := db.ListenActivityItem{
|
||||
Start: row.BucketStart,
|
||||
Listens: row.ListenCount,
|
||||
}
|
||||
listenActivity[i] = t
|
||||
}
|
||||
l.Debug().Msgf("Database responded with %d steps", len(rows))
|
||||
}
|
||||
|
||||
return listenActivity, nil
|
||||
}
|
||||
211
internal/db/psql/listen_activity_test.go
Normal file
211
internal/db/psql/listen_activity_test.go
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
package psql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func flattenListenCounts(items []db.ListenActivityItem) []int64 {
|
||||
ret := make([]int64, len(items))
|
||||
for i, v := range items {
|
||||
ret[i] = v.Listens
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func TestListenActivity(t *testing.T) {
|
||||
truncateTestData(t)
|
||||
|
||||
err := store.Exec(context.Background(),
|
||||
`INSERT INTO artists (musicbrainz_id)
|
||||
VALUES ('00000000-0000-0000-0000-000000000001'),
|
||||
('00000000-0000-0000-0000-000000000002')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Move artist names into artist_aliases
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
|
||||
VALUES (1, 'Artist One', 'Testing', true),
|
||||
(2, 'Artist Two', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert release groups
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO releases (musicbrainz_id)
|
||||
VALUES ('00000000-0000-0000-0000-000000000011'),
|
||||
('00000000-0000-0000-0000-000000000022')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Move release titles into release_aliases
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
|
||||
VALUES (1, 'Release One', 'Testing', true),
|
||||
(2, 'Release Two', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert tracks
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO tracks (musicbrainz_id, release_id)
|
||||
VALUES ('11111111-1111-1111-1111-111111111111', 1),
|
||||
('22222222-2222-2222-2222-222222222222', 2)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Move track titles into track_aliases
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO track_aliases (track_id, alias, source, is_primary)
|
||||
VALUES (1, 'Track One', 'Testing', true),
|
||||
(2, 'Track Two', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Associate tracks with artists
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_tracks (artist_id, track_id)
|
||||
VALUES (1, 1), (2, 2)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert listens
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO listens (user_id, track_id, listened_at)
|
||||
VALUES (1, 1, NOW() - INTERVAL '1 day'),
|
||||
(1, 1, NOW() - INTERVAL '2 days'),
|
||||
(1, 1, NOW() - INTERVAL '1 week 1 day'),
|
||||
(1, 1, NOW() - INTERVAL '1 month 1 day'),
|
||||
(1, 1, NOW() - INTERVAL '1 year 1 day'),
|
||||
(1, 2, NOW() - INTERVAL '1 day'),
|
||||
(1, 2, NOW() - INTERVAL '2 days'),
|
||||
(1, 2, NOW() - INTERVAL '1 week 1 day'),
|
||||
(1, 2, NOW() - INTERVAL '1 month 1 day'),
|
||||
(1, 2, NOW() - INTERVAL '1 year 1 day')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Test for opts.Step = db.StepDay
|
||||
activity, err := store.GetListenActivity(ctx, db.ListenActivityOpts{Step: db.StepDay})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, activity, db.DefaultRange)
|
||||
assert.Equal(t, []int64{0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 2, 0}, flattenListenCounts(activity))
|
||||
|
||||
// Truncate listens table and insert specific dates for testing opts.Step = db.StepMonth
|
||||
err = store.Exec(context.Background(), `TRUNCATE TABLE listens`)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO listens (user_id, track_id, listened_at)
|
||||
VALUES (1, 1, NOW() - INTERVAL '1 month'),
|
||||
(1, 1, NOW() - INTERVAL '2 months'),
|
||||
(1, 1, NOW() - INTERVAL '3 months'),
|
||||
(1, 2, NOW() - INTERVAL '1 month'),
|
||||
(1, 2, NOW() - INTERVAL '2 months')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
activity, err = store.GetListenActivity(ctx, db.ListenActivityOpts{Step: db.StepMonth, Range: 8})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, activity, 8)
|
||||
assert.Equal(t, []int64{0, 0, 0, 0, 1, 2, 2, 0}, flattenListenCounts(activity))
|
||||
|
||||
// Truncate listens table and insert specific dates for testing opts.Step = db.StepYear
|
||||
err = store.Exec(context.Background(), `TRUNCATE TABLE listens RESTART IDENTITY`)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO listens (user_id, track_id, listened_at)
|
||||
VALUES (1, 1, NOW() - INTERVAL '1 year'),
|
||||
(1, 1, NOW() - INTERVAL '2 years'),
|
||||
(1, 2, NOW() - INTERVAL '1 year'),
|
||||
(1, 2, NOW() - INTERVAL '3 years')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
activity, err = store.GetListenActivity(ctx, db.ListenActivityOpts{Step: db.StepYear})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, activity, db.DefaultRange)
|
||||
assert.Equal(t, []int64{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 0}, flattenListenCounts(activity))
|
||||
// Truncate and insert data for a specific month/year
|
||||
err = store.Exec(context.Background(), `TRUNCATE TABLE listens RESTART IDENTITY`)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.Exec(context.Background(), `
|
||||
INSERT INTO listens (user_id, track_id, listened_at)
|
||||
VALUES (1, 1, DATE '2024-03-10'),
|
||||
(1, 2, DATE '2024-03-20')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
activity, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
|
||||
Step: db.StepDay,
|
||||
Month: 3,
|
||||
Year: 2024,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, activity, 31) // number of days in march
|
||||
assert.EqualValues(t, 1, activity[8].Listens)
|
||||
assert.EqualValues(t, 1, activity[18].Listens)
|
||||
|
||||
// Truncate and insert listens associated with two different albums
|
||||
err = store.Exec(context.Background(), `TRUNCATE TABLE listens RESTART IDENTITY`)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.Exec(context.Background(), `
|
||||
INSERT INTO listens (user_id, track_id, listened_at)
|
||||
VALUES (1, 1, NOW() - INTERVAL '1 day'), (1, 1, NOW() - INTERVAL '2 days'),
|
||||
(1, 2, NOW() - INTERVAL '1 day')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
activity, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
|
||||
Step: db.StepDay,
|
||||
AlbumID: 1, // Track 1 only
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, activity, db.DefaultRange)
|
||||
assert.Equal(t, []int64{0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0}, flattenListenCounts(activity))
|
||||
|
||||
activity, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
|
||||
Step: db.StepDay,
|
||||
TrackID: 1, // Track 1 only
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, activity, db.DefaultRange)
|
||||
assert.Equal(t, []int64{0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0}, flattenListenCounts(activity))
|
||||
|
||||
activity, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
|
||||
Step: db.StepDay,
|
||||
ArtistID: 2, // Should only include listens to Track 2
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, activity, db.DefaultRange)
|
||||
assert.Equal(t, []int64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}, flattenListenCounts(activity))
|
||||
|
||||
// month without year is disallowed
|
||||
_, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
|
||||
Step: db.StepDay,
|
||||
Month: 5,
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
// invalid options
|
||||
_, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
|
||||
Year: -10,
|
||||
})
|
||||
require.Error(t, err)
|
||||
_, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
|
||||
Year: 2025,
|
||||
Month: -10,
|
||||
})
|
||||
require.Error(t, err)
|
||||
_, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
|
||||
Range: -1,
|
||||
})
|
||||
require.Error(t, err)
|
||||
_, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
|
||||
AlbumID: -1,
|
||||
})
|
||||
require.Error(t, err)
|
||||
_, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
|
||||
ArtistID: -1,
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
}
|
||||
219
internal/db/psql/listen_test.go
Normal file
219
internal/db/psql/listen_test.go
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
package psql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testDataForListens(t *testing.T) {
|
||||
truncateTestData(t)
|
||||
// Insert artists
|
||||
err := store.Exec(context.Background(),
|
||||
`INSERT INTO artists (musicbrainz_id)
|
||||
VALUES ('00000000-0000-0000-0000-000000000001'),
|
||||
('00000000-0000-0000-0000-000000000002')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert artist aliases
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
|
||||
VALUES (1, 'Artist One', 'Testing', true),
|
||||
(2, 'Artist Two', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert release groups
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO releases (musicbrainz_id)
|
||||
VALUES ('00000000-0000-0000-0000-000000000011'),
|
||||
('00000000-0000-0000-0000-000000000022')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert release aliases
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
|
||||
VALUES (1, 'Release One', 'Testing', true),
|
||||
(2, 'Release Two', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert tracks
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO tracks (musicbrainz_id, release_id)
|
||||
VALUES ('11111111-1111-1111-1111-111111111111', 1),
|
||||
('22222222-2222-2222-2222-222222222222', 2)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert track aliases
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO track_aliases (track_id, alias, source, is_primary)
|
||||
VALUES (1, 'Track One', 'Testing', true),
|
||||
(2, 'Track Two', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert artist track associations
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_tracks (track_id, artist_id)
|
||||
VALUES (1, 1),
|
||||
(2, 2)`)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestGetListens(t *testing.T) {
|
||||
testDataForTopItems(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test valid
|
||||
resp, err := store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 10)
|
||||
assert.Equal(t, int64(10), resp.TotalCount)
|
||||
require.Len(t, resp.Items[0].Track.Artists, 1)
|
||||
require.Len(t, resp.Items[1].Track.Artists, 1)
|
||||
// ensure tracks are in the right order (time, desc)
|
||||
assert.Equal(t, "Artist Four", resp.Items[0].Track.Artists[0].Name)
|
||||
assert.Equal(t, "Artist Three", resp.Items[1].Track.Artists[0].Name)
|
||||
|
||||
// Test pagination
|
||||
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 2, Period: db.PeriodAllTime})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 1)
|
||||
require.Len(t, resp.Items[0].Track.Artists, 1)
|
||||
assert.Equal(t, true, resp.HasNextPage)
|
||||
assert.EqualValues(t, 2, resp.CurrentPage)
|
||||
assert.EqualValues(t, 1, resp.ItemsPerPage)
|
||||
assert.EqualValues(t, 10, resp.TotalCount)
|
||||
assert.Equal(t, "Artist Three", resp.Items[0].Track.Artists[0].Name)
|
||||
|
||||
// Test page out of range
|
||||
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Limit: 10, Page: 10, Period: db.PeriodAllTime})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, resp.Items)
|
||||
assert.False(t, resp.HasNextPage)
|
||||
|
||||
// Test invalid inputs
|
||||
_, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Limit: -1, Page: 0})
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: -1})
|
||||
assert.Error(t, err)
|
||||
|
||||
// Test specify period
|
||||
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodDay})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 0) // empty
|
||||
assert.Equal(t, int64(0), resp.TotalCount)
|
||||
// should default to PeriodDay
|
||||
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 0) // empty
|
||||
assert.Equal(t, int64(0), resp.TotalCount)
|
||||
|
||||
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodWeek})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 1)
|
||||
assert.Equal(t, int64(1), resp.TotalCount)
|
||||
|
||||
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodMonth})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 3)
|
||||
assert.Equal(t, int64(3), resp.TotalCount)
|
||||
|
||||
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodYear})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 6)
|
||||
assert.Equal(t, int64(6), resp.TotalCount)
|
||||
|
||||
// Test filter by artists, releases, and tracks
|
||||
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, ArtistID: 1})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 4)
|
||||
assert.Equal(t, int64(4), resp.TotalCount)
|
||||
|
||||
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, AlbumID: 2})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 3)
|
||||
assert.Equal(t, int64(3), resp.TotalCount)
|
||||
|
||||
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, TrackID: 3})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 2)
|
||||
assert.Equal(t, int64(2), resp.TotalCount)
|
||||
// when both artistID and albumID are specified, artist id is ignored
|
||||
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, AlbumID: 2, ArtistID: 1})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 3)
|
||||
assert.Equal(t, int64(3), resp.TotalCount)
|
||||
|
||||
// Test specify dates
|
||||
|
||||
testDataAbsoluteListenTimes(t)
|
||||
|
||||
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Year: 2023})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 4)
|
||||
assert.Equal(t, int64(4), resp.TotalCount)
|
||||
|
||||
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Month: 6, Year: 2024})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 3)
|
||||
assert.Equal(t, int64(3), resp.TotalCount)
|
||||
|
||||
// invalid, year required with month
|
||||
_, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Month: 10})
|
||||
require.Error(t, err)
|
||||
|
||||
}
|
||||
|
||||
func TestSaveListen(t *testing.T) {
|
||||
testDataForListens(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test SaveListen with valid inputs
|
||||
err := store.SaveListen(ctx, db.SaveListenOpts{
|
||||
TrackID: 1,
|
||||
Time: time.Now(),
|
||||
UserID: 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the listen was saved
|
||||
exists, err := store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM listens
|
||||
WHERE track_id = $1
|
||||
)`, 1)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "expected listen to exist")
|
||||
|
||||
// Test SaveListen with missing TrackID
|
||||
err = store.SaveListen(ctx, db.SaveListenOpts{
|
||||
TrackID: 0,
|
||||
Time: time.Now(),
|
||||
})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDeleteListen(t *testing.T) {
|
||||
testDataForListens(t)
|
||||
ctx := context.Background()
|
||||
|
||||
err := store.Exec(ctx, `
|
||||
INSERT INTO listens (user_id, track_id, listened_at)
|
||||
VALUES (1, 1, to_timestamp(1749464138.0))`)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.DeleteListen(ctx, 1, time.Unix(1749464138, 0))
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err := store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM listens
|
||||
WHERE track_id = $1
|
||||
)`, 1)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "expected listen to be deleted")
|
||||
}
|
||||
109
internal/db/psql/merge.go
Normal file
109
internal/db/psql/merge.go
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gabehf/koito/internal/logger"
|
||||
"github.com/gabehf/koito/internal/repository"
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
func (d *Psql) MergeTracks(ctx context.Context, fromId, toId int32) error {
|
||||
l := logger.FromContext(ctx)
|
||||
l.Info().Msgf("Merging track %d into track %d", fromId, toId)
|
||||
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to begin transaction")
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
qtx := d.q.WithTx(tx)
|
||||
err = qtx.UpdateTrackIdForListens(ctx, repository.UpdateTrackIdForListensParams{
|
||||
TrackID: fromId,
|
||||
TrackID_2: toId,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = qtx.CleanOrphanedEntries(ctx)
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to clean orphaned entries")
|
||||
return err
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (d *Psql) MergeAlbums(ctx context.Context, fromId, toId int32) error {
|
||||
l := logger.FromContext(ctx)
|
||||
l.Info().Msgf("Merging album %d into album %d", fromId, toId)
|
||||
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to begin transaction")
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
qtx := d.q.WithTx(tx)
|
||||
err = qtx.UpdateReleaseForAll(ctx, repository.UpdateReleaseForAllParams{
|
||||
ReleaseID: fromId,
|
||||
ReleaseID_2: toId,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = qtx.CleanOrphanedEntries(ctx)
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to clean orphaned entries")
|
||||
return err
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (d *Psql) MergeArtists(ctx context.Context, fromId, toId int32) error {
|
||||
l := logger.FromContext(ctx)
|
||||
l.Info().Msgf("Merging artist %d into artist %d", fromId, toId)
|
||||
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to begin transaction")
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
qtx := d.q.WithTx(tx)
|
||||
err = qtx.DeleteConflictingArtistTracks(ctx, repository.DeleteConflictingArtistTracksParams{
|
||||
ArtistID: fromId,
|
||||
ArtistID_2: toId,
|
||||
})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to delete conflicting artist tracks")
|
||||
return err
|
||||
}
|
||||
err = qtx.DeleteConflictingArtistReleases(ctx, repository.DeleteConflictingArtistReleasesParams{
|
||||
ArtistID: fromId,
|
||||
ArtistID_2: toId,
|
||||
})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to delete conflicting artist releases")
|
||||
return err
|
||||
}
|
||||
err = qtx.UpdateArtistTracks(ctx, repository.UpdateArtistTracksParams{
|
||||
ArtistID: fromId,
|
||||
ArtistID_2: toId,
|
||||
})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to update artist tracks")
|
||||
return err
|
||||
}
|
||||
err = qtx.UpdateArtistReleases(ctx, repository.UpdateArtistReleasesParams{
|
||||
ArtistID: fromId,
|
||||
ArtistID_2: toId,
|
||||
})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to update artist releases")
|
||||
return err
|
||||
}
|
||||
err = qtx.CleanOrphanedEntries(ctx)
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to clean orphaned entries")
|
||||
return err
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
124
internal/db/psql/merge_test.go
Normal file
124
internal/db/psql/merge_test.go
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
package psql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setupTestDataForMerge(t *testing.T) {
|
||||
truncateTestData(t)
|
||||
// Insert artists
|
||||
err := store.Exec(context.Background(),
|
||||
`INSERT INTO artists (musicbrainz_id)
|
||||
VALUES ('00000000-0000-0000-0000-000000000001'),
|
||||
('00000000-0000-0000-0000-000000000002')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
|
||||
VALUES (1, 'Artist One', 'Testing', true),
|
||||
(2, 'Artist Two', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert albums
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO releases (musicbrainz_id)
|
||||
VALUES ('11111111-1111-1111-1111-111111111111'),
|
||||
('22222222-2222-2222-2222-222222222222')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
|
||||
VALUES (1, 'Album One', 'Testing', true),
|
||||
(2, 'Album Two', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert tracks
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO tracks (musicbrainz_id, release_id)
|
||||
VALUES ('33333333-3333-3333-3333-333333333333', 1),
|
||||
('44444444-4444-4444-4444-444444444444', 2)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO track_aliases (track_id, alias, source, is_primary)
|
||||
VALUES (1, 'Track One', 'Testing', true),
|
||||
(2, 'Track Two', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Associate artists with albums and tracks
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_releases (artist_id, release_id)
|
||||
VALUES (1, 1), (2, 2)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_tracks (artist_id, track_id)
|
||||
VALUES (1, 1), (2, 2)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert listens
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO listens (user_id, track_id, listened_at)
|
||||
VALUES (1, 1, NOW() - INTERVAL '1 day'),
|
||||
(1, 2, NOW() - INTERVAL '2 days')`)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestMergeTracks(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
setupTestDataForMerge(t)
|
||||
|
||||
// Merge Track 1 into Track 2
|
||||
err := store.MergeTracks(ctx, 1, 2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify listens are updated
|
||||
var count int
|
||||
count, err = store.Count(ctx, `SELECT COUNT(*) FROM listens WHERE track_id = 2`)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count, "expected all listens to be merged into Track 2")
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
|
||||
func TestMergeAlbums(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
setupTestDataForMerge(t)
|
||||
|
||||
// Merge Album 1 into Album 2
|
||||
err := store.MergeAlbums(ctx, 1, 2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify tracks are updated
|
||||
var count int
|
||||
count, err = store.Count(ctx, `SELECT COUNT(*) FROM tracks WHERE release_id = 2`)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count, "expected all tracks to be merged into Album 2")
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
|
||||
func TestMergeArtists(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
setupTestDataForMerge(t)
|
||||
|
||||
// Merge Artist 1 into Artist 2
|
||||
err := store.MergeArtists(ctx, 1, 2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify artist associations are updated
|
||||
var count int
|
||||
count, err = store.Count(ctx, `SELECT COUNT(*) FROM artist_tracks WHERE artist_id = 2`)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count, "expected all tracks to be associated with Artist 2")
|
||||
|
||||
count, err = store.Count(ctx, `SELECT COUNT(*) FROM artist_releases WHERE artist_id = 2`)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count, "expected all releases to be associated with Artist 2")
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
119
internal/db/psql/psql.go
Normal file
119
internal/db/psql/psql.go
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
// package psql implements the db.DB interface using psx and a sql generated repository
|
||||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/gabehf/koito/internal/cfg"
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/gabehf/koito/internal/repository"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/pressly/goose/v3"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultItemsPerPage = 20
|
||||
)
|
||||
|
||||
type Psql struct {
|
||||
q *repository.Queries
|
||||
conn *pgxpool.Pool
|
||||
}
|
||||
|
||||
func New() (*Psql, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
config, err := pgxpool.ParseConfig(cfg.DatabaseUrl())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse pgx config: %w", err)
|
||||
}
|
||||
|
||||
config.ConnConfig.ConnectTimeout = 15 * time.Second
|
||||
|
||||
pool, err := pgxpool.NewWithConfig(ctx, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create pgx pool: %w", err)
|
||||
}
|
||||
|
||||
if err := pool.Ping(ctx); err != nil {
|
||||
pool.Close()
|
||||
return nil, fmt.Errorf("database not reachable: %w", err)
|
||||
}
|
||||
|
||||
sqlDB, err := sql.Open("pgx", cfg.DatabaseUrl())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open db for migrations: %w", err)
|
||||
}
|
||||
|
||||
_, filename, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to get caller info")
|
||||
}
|
||||
migrationsPath := filepath.Join(filepath.Dir(filename), "..", "..", "..", "db", "migrations")
|
||||
|
||||
if err := goose.Up(sqlDB, migrationsPath); err != nil {
|
||||
return nil, fmt.Errorf("goose failed: %w", err)
|
||||
}
|
||||
_ = sqlDB.Close()
|
||||
|
||||
return &Psql{
|
||||
q: repository.New(pool),
|
||||
conn: pool,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Not part of the DB interface this package implements. Only used for testing.
|
||||
func (d *Psql) Exec(ctx context.Context, query string, args ...any) error {
|
||||
_, err := d.conn.Exec(ctx, query, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// Not part of the DB interface this package implements. Only used for testing.
|
||||
func (d *Psql) RowExists(ctx context.Context, query string, args ...any) (bool, error) {
|
||||
var exists bool
|
||||
err := d.conn.QueryRow(ctx, query, args...).Scan(&exists)
|
||||
return exists, err
|
||||
}
|
||||
|
||||
func (p *Psql) Count(ctx context.Context, query string, args ...any) (count int, err error) {
|
||||
err = p.conn.QueryRow(ctx, query, args...).Scan(&count)
|
||||
return
|
||||
}
|
||||
|
||||
// Exposes p.conn.QueryRow. Only used for testing. Not part of the DB interface this package implements.
|
||||
func (p *Psql) QueryRow(ctx context.Context, query string, args ...any) pgx.Row {
|
||||
return p.conn.QueryRow(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (d *Psql) Close(ctx context.Context) {
|
||||
d.conn.Close()
|
||||
}
|
||||
|
||||
func (d *Psql) Ping(ctx context.Context) error {
|
||||
return d.conn.Ping(ctx)
|
||||
}
|
||||
|
||||
func stepToInterval(p db.StepInterval) pgtype.Interval {
|
||||
var interval pgtype.Interval
|
||||
switch p {
|
||||
case db.StepDay:
|
||||
interval.Days = 1
|
||||
case db.StepWeek:
|
||||
interval.Days = 7
|
||||
case db.StepMonth:
|
||||
interval.Months = 1
|
||||
case db.StepYear:
|
||||
interval.Months = 12
|
||||
}
|
||||
interval.Valid = true
|
||||
return interval
|
||||
}
|
||||
186
internal/db/psql/psql_test.go
Normal file
186
internal/db/psql/psql_test.go
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
package psql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"testing"
|
||||
|
||||
"github.com/gabehf/koito/internal/cfg"
|
||||
"github.com/gabehf/koito/internal/db/psql"
|
||||
_ "github.com/gabehf/koito/testing_init"
|
||||
"github.com/ory/dockertest/v3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var store *psql.Psql
|
||||
|
||||
func getTestGetenv(resource *dockertest.Resource) func(string) string {
|
||||
return func(env string) string {
|
||||
switch env {
|
||||
case cfg.DATABASE_URL_ENV:
|
||||
return fmt.Sprintf("postgres://postgres:secret@localhost:%s", resource.GetPort("5432/tcp"))
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// uses a sensible default on windows (tcp/http) and linux/osx (socket)
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
log.Fatalf("Could not construct pool: %s", err)
|
||||
}
|
||||
|
||||
// uses pool to try to connect to Docker
|
||||
err = pool.Client.Ping()
|
||||
if err != nil {
|
||||
log.Fatalf("Could not connect to Docker: %s", err)
|
||||
}
|
||||
|
||||
// pulls an image, creates a container based on it and runs it
|
||||
resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret"})
|
||||
if err != nil {
|
||||
log.Fatalf("Could not start resource: %s", err)
|
||||
}
|
||||
|
||||
err = cfg.Load(getTestGetenv(resource))
|
||||
if err != nil {
|
||||
log.Fatalf("Could not load cfg: %s", err)
|
||||
}
|
||||
|
||||
// exponential backoff-retry, because the application in the container might not be ready to accept connections yet
|
||||
if err := pool.Retry(func() error {
|
||||
var err error
|
||||
store, err = psql.New()
|
||||
if err != nil {
|
||||
log.Println("Failed to connect to test database, retrying...")
|
||||
return err
|
||||
}
|
||||
return store.Ping(context.Background())
|
||||
}); err != nil {
|
||||
log.Fatalf("Could not connect to database: %s", err)
|
||||
}
|
||||
|
||||
// as of go1.15 testing.M returns the exit code of m.Run(), so it is safe to use defer here
|
||||
defer func() {
|
||||
if err := pool.Purge(resource); err != nil {
|
||||
log.Fatalf("Could not purge resource: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// insert a user into the db with id 1 to use for tests
|
||||
err = store.Exec(context.Background(), `INSERT INTO users (username, password) VALUES ('test', DECODE('abc123', 'hex'))`)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to insert test user: %v", err)
|
||||
}
|
||||
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func testDataForTopItems(t *testing.T) {
|
||||
truncateTestData(t)
|
||||
|
||||
// artist 1 has most listens older than 1 year
|
||||
// artist 2 has most listens older than 1 month
|
||||
// artist 3 has most listens older than 1 week
|
||||
// artist 4 has least listens
|
||||
|
||||
err := store.Exec(context.Background(),
|
||||
`INSERT INTO artists (musicbrainz_id)
|
||||
VALUES ('00000000-0000-0000-0000-000000000001'),
|
||||
('00000000-0000-0000-0000-000000000002'),
|
||||
('00000000-0000-0000-0000-000000000003'),
|
||||
('00000000-0000-0000-0000-000000000004')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
|
||||
VALUES (1, 'Artist One', 'Testing', true),
|
||||
(2, 'Artist Two', 'Testing', true),
|
||||
(3, 'Artist Three', 'Testing', true),
|
||||
(4, 'Artist Four', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert release groups
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO releases (musicbrainz_id)
|
||||
VALUES ('00000000-0000-0000-0000-000000000011'),
|
||||
('00000000-0000-0000-0000-000000000022'),
|
||||
('00000000-0000-0000-0000-000000000033'),
|
||||
('00000000-0000-0000-0000-000000000044')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
|
||||
VALUES (1, 'Release One', 'Testing', true),
|
||||
(2, 'Release Two', 'Testing', true),
|
||||
(3, 'Release Three', 'Testing', true),
|
||||
(4, 'Release Four', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert release groups
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_releases (release_id, artist_id)
|
||||
VALUES (1, 1), (2, 2), (3, 3), (4, 4)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert tracks
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO tracks (musicbrainz_id, release_id, duration)
|
||||
VALUES ('11111111-1111-1111-1111-111111111111', 1, 100),
|
||||
('22222222-2222-2222-2222-222222222222', 2, 100),
|
||||
('33333333-3333-3333-3333-333333333333', 3, 100),
|
||||
('44444444-4444-4444-4444-444444444444', 4, 100)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO track_aliases (track_id, alias, source, is_primary)
|
||||
VALUES (1, 'Track One', 'Testing', true),
|
||||
(2, 'Track Two', 'Testing', true),
|
||||
(3, 'Track Three', 'Testing', true),
|
||||
(4, 'Track Four', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Associate tracks with artists
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_tracks (artist_id, track_id)
|
||||
VALUES (1, 1), (2, 2), (3, 3), (4, 4)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert listens
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO listens (user_id, track_id, listened_at)
|
||||
VALUES (1, 1, NOW() - INTERVAL '2 years 1 day'),
|
||||
(1, 1, NOW() - INTERVAL '2 years 2 days'),
|
||||
(1, 1, NOW() - INTERVAL '2 years 3 days'),
|
||||
(1, 1, NOW() - INTERVAL '2 years 4 days'),
|
||||
(1, 2, NOW() - INTERVAL '2 months 1 day'),
|
||||
(1, 2, NOW() - INTERVAL '2 months 2 days'),
|
||||
(1, 2, NOW() - INTERVAL '2 months 3 days'),
|
||||
(1, 3, NOW() - INTERVAL '2 weeks'),
|
||||
(1, 3, NOW() - INTERVAL '2 weeks 1 day'),
|
||||
(1, 4, NOW() - INTERVAL '2 days')`)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func testDataAbsoluteListenTimes(t *testing.T) {
|
||||
err := store.Exec(context.Background(),
|
||||
`TRUNCATE listens`)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO listens (user_id, track_id, listened_at)
|
||||
VALUES (1, 1, '2023-06-22 19:11:25-07'),
|
||||
(1, 1, '2023-06-22 19:12:25-07'),
|
||||
(1, 1, '2023-06-22 19:13:25-07'),
|
||||
(1, 1, '2023-06-22 19:14:25-07'),
|
||||
(1, 2, '2024-06-22 19:15:25-07'),
|
||||
(1, 2, '2024-06-22 19:16:25-07'),
|
||||
(1, 2, '2024-06-22 19:17:25-07'),
|
||||
(1, 3, '2024-10-02 19:18:25-07'),
|
||||
(1, 3, '2024-10-02 19:19:25-07'),
|
||||
(1, 4, '2025-05-16 19:20:25-07')`)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
151
internal/db/psql/search.go
Normal file
151
internal/db/psql/search.go
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/gabehf/koito/internal/models"
|
||||
"github.com/gabehf/koito/internal/repository"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const searchItemLimit = 5
|
||||
const substringSearchLength = 6
|
||||
|
||||
func (d *Psql) SearchArtists(ctx context.Context, q string) ([]*models.Artist, error) {
|
||||
if len(q) < substringSearchLength {
|
||||
rows, err := d.q.SearchArtistsBySubstring(ctx, repository.SearchArtistsBySubstringParams{
|
||||
Column1: pgtype.Text{String: q, Valid: true},
|
||||
Limit: searchItemLimit,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret := make([]*models.Artist, len(rows))
|
||||
for i, row := range rows {
|
||||
ret[i] = &models.Artist{
|
||||
ID: row.ID,
|
||||
MbzID: row.MusicBrainzID,
|
||||
Name: row.Name,
|
||||
Image: row.Image,
|
||||
}
|
||||
}
|
||||
return ret, nil
|
||||
} else {
|
||||
rows, err := d.q.SearchArtists(ctx, repository.SearchArtistsParams{
|
||||
Similarity: q,
|
||||
Limit: searchItemLimit,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret := make([]*models.Artist, len(rows))
|
||||
for i, row := range rows {
|
||||
ret[i] = &models.Artist{
|
||||
ID: row.ID,
|
||||
MbzID: row.MusicBrainzID,
|
||||
Name: row.Name,
|
||||
Image: row.Image,
|
||||
}
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Psql) SearchAlbums(ctx context.Context, q string) ([]*models.Album, error) {
|
||||
if len(q) < substringSearchLength {
|
||||
rows, err := d.q.SearchReleasesBySubstring(ctx, repository.SearchReleasesBySubstringParams{
|
||||
Column1: pgtype.Text{String: q, Valid: true},
|
||||
Limit: searchItemLimit,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret := make([]*models.Album, len(rows))
|
||||
for i, row := range rows {
|
||||
ret[i] = &models.Album{
|
||||
ID: row.ID,
|
||||
MbzID: row.MusicBrainzID,
|
||||
Title: row.Title,
|
||||
VariousArtists: row.VariousArtists,
|
||||
Image: row.Image,
|
||||
}
|
||||
err = json.Unmarshal(row.Artists, &ret[i].Artists)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return ret, nil
|
||||
} else {
|
||||
rows, err := d.q.SearchReleases(ctx, repository.SearchReleasesParams{
|
||||
Similarity: q,
|
||||
Limit: searchItemLimit,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret := make([]*models.Album, len(rows))
|
||||
for i, row := range rows {
|
||||
ret[i] = &models.Album{
|
||||
ID: row.ID,
|
||||
MbzID: row.MusicBrainzID,
|
||||
Title: row.Title,
|
||||
VariousArtists: row.VariousArtists,
|
||||
Image: row.Image,
|
||||
}
|
||||
err = json.Unmarshal(row.Artists, &ret[i].Artists)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Psql) SearchTracks(ctx context.Context, q string) ([]*models.Track, error) {
|
||||
if len(q) < substringSearchLength {
|
||||
rows, err := d.q.SearchTracksBySubstring(ctx, repository.SearchTracksBySubstringParams{
|
||||
Column1: pgtype.Text{String: q, Valid: true},
|
||||
Limit: searchItemLimit,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret := make([]*models.Track, len(rows))
|
||||
for i, row := range rows {
|
||||
ret[i] = &models.Track{
|
||||
ID: row.ID,
|
||||
MbzID: row.MusicBrainzID,
|
||||
Title: row.Title,
|
||||
Image: row.Image,
|
||||
}
|
||||
err = json.Unmarshal(row.Artists, &ret[i].Artists)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return ret, nil
|
||||
} else {
|
||||
rows, err := d.q.SearchTracks(ctx, repository.SearchTracksParams{
|
||||
Similarity: q,
|
||||
Limit: searchItemLimit,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret := make([]*models.Track, len(rows))
|
||||
for i, row := range rows {
|
||||
ret[i] = &models.Track{
|
||||
ID: row.ID,
|
||||
MbzID: row.MusicBrainzID,
|
||||
Title: row.Title,
|
||||
Image: row.Image,
|
||||
}
|
||||
err = json.Unmarshal(row.Artists, &ret[i].Artists)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
}
|
||||
116
internal/db/psql/search_test.go
Normal file
116
internal/db/psql/search_test.go
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
package psql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setupTestDataForSearch(t *testing.T) {
|
||||
truncateTestData(t)
|
||||
|
||||
// Insert artists
|
||||
err := store.Exec(context.Background(),
|
||||
`INSERT INTO artists (musicbrainz_id)
|
||||
VALUES ('00000000-0000-0000-0000-000000000001'),
|
||||
('00000000-0000-0000-0000-000000000002')`)
|
||||
require.NoError(t, err)
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
|
||||
VALUES (1, 'Artist One With A Long Name', 'Testing', true),
|
||||
(2, 'Artist Two', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert albums
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO releases (musicbrainz_id, various_artists)
|
||||
VALUES ('11111111-1111-1111-1111-111111111111', false),
|
||||
('22222222-2222-2222-2222-222222222222', true)`)
|
||||
require.NoError(t, err)
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
|
||||
VALUES (1, 'Album One With A Long Name', 'Testing', true),
|
||||
(2, 'Album Two', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert tracks
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO tracks (musicbrainz_id, release_id)
|
||||
VALUES ('33333333-3333-3333-3333-333333333333', 1),
|
||||
('44444444-4444-4444-4444-444444444444', 2)`)
|
||||
require.NoError(t, err)
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO track_aliases (track_id, alias, source, is_primary)
|
||||
VALUES (1, 'Track One With A Long Name', 'Testing', true),
|
||||
(2, 'Track Two', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Associate artists with albums and tracks
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_releases (artist_id, release_id)
|
||||
VALUES (1, 1), (2, 2)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_tracks (artist_id, track_id)
|
||||
VALUES (1, 1), (2, 2)`)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSearchArtists(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
setupTestDataForSearch(t)
|
||||
|
||||
// Search for "Artist One With A Long Name"
|
||||
results, err := store.SearchArtists(ctx, "Artist One With A Long Name")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
assert.Equal(t, "Artist One With A Long Name", results[0].Name)
|
||||
|
||||
// Search for substring "Artist"
|
||||
results, err = store.SearchArtists(ctx, "Arti")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 2)
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
|
||||
func TestSearchAlbums(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
setupTestDataForSearch(t)
|
||||
|
||||
// Search for "Album One With A Long Name"
|
||||
results, err := store.SearchAlbums(ctx, "Album One With A Long Name")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
assert.Equal(t, "Album One With A Long Name", results[0].Title)
|
||||
|
||||
// Search for substring "Album"
|
||||
results, err = store.SearchAlbums(ctx, "Albu")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 2)
|
||||
assert.NotNil(t, results[0].Artists)
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
|
||||
func TestSearchTracks(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
setupTestDataForSearch(t)
|
||||
|
||||
// Search for "Track One With A Long Name"
|
||||
results, err := store.SearchTracks(ctx, "Track One With A Long Name")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
assert.Equal(t, "Track One With A Long Name", results[0].Title)
|
||||
|
||||
// Search for substring "Track"
|
||||
results, err = store.SearchTracks(ctx, "Trac")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 2)
|
||||
assert.NotNil(t, results[0].Artists)
|
||||
|
||||
truncateTestData(t)
|
||||
}
|
||||
59
internal/db/psql/sessions.go
Normal file
59
internal/db/psql/sessions.go
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/gabehf/koito/internal/models"
|
||||
"github.com/gabehf/koito/internal/repository"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
func (d *Psql) SaveSession(ctx context.Context, userID int32, expiresAt time.Time, persistent bool) (*models.Session, error) {
|
||||
session, err := d.q.InsertSession(ctx, repository.InsertSessionParams{
|
||||
ID: uuid.New(),
|
||||
UserID: userID,
|
||||
ExpiresAt: expiresAt,
|
||||
Persistent: persistent,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &models.Session{
|
||||
ID: session.ID,
|
||||
UserID: session.UserID,
|
||||
CreatedAt: session.CreatedAt,
|
||||
ExpiresAt: session.ExpiresAt,
|
||||
Persistent: session.Persistent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Psql) RefreshSession(ctx context.Context, sessionId uuid.UUID, expiresAt time.Time) error {
|
||||
return d.q.UpdateSessionExpiry(ctx, repository.UpdateSessionExpiryParams{
|
||||
ID: sessionId,
|
||||
ExpiresAt: expiresAt,
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Psql) DeleteSession(ctx context.Context, sessionId uuid.UUID) error {
|
||||
return d.q.DeleteSession(ctx, sessionId)
|
||||
}
|
||||
|
||||
// Returns nil, nil when no database entries are found
|
||||
func (d *Psql) GetUserBySession(ctx context.Context, sessionId uuid.UUID) (*models.User, error) {
|
||||
row, err := d.q.GetUserBySession(ctx, sessionId)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &models.User{
|
||||
ID: row.ID,
|
||||
Username: row.Username,
|
||||
Password: row.Password,
|
||||
Role: models.UserRole(row.Role),
|
||||
}, nil
|
||||
}
|
||||
101
internal/db/psql/sessions_test.go
Normal file
101
internal/db/psql/sessions_test.go
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
package psql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func truncateTestDataForSessions(t *testing.T) {
|
||||
err := store.Exec(context.Background(),
|
||||
`TRUNCATE
|
||||
sessions
|
||||
RESTART IDENTITY CASCADE`,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
func TestSaveSession(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Save a session for the user
|
||||
expiresAt := time.Now().Add(24 * time.Hour).UTC()
|
||||
session, err := store.SaveSession(ctx, 1, expiresAt, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, session)
|
||||
assert.Equal(t, int32(1), session.UserID)
|
||||
assert.Equal(t, true, session.Persistent)
|
||||
assert.WithinDuration(t, expiresAt, session.ExpiresAt, time.Second)
|
||||
|
||||
truncateTestDataForSessions(t)
|
||||
}
|
||||
|
||||
func TestRefreshSession(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Save a session first
|
||||
expiresAt := time.Now().Add(-1 * time.Minute)
|
||||
session, err := store.SaveSession(ctx, 1, expiresAt, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Refresh the session expiry
|
||||
newExpiresAt := time.Now().Add(48 * time.Hour)
|
||||
err = store.RefreshSession(ctx, session.ID, newExpiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Can only retrieve a session with an expiresAt > time.Now()
|
||||
_, err = store.GetUserBySession(ctx, session.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
truncateTestDataForSessions(t)
|
||||
}
|
||||
|
||||
func TestDeleteSession(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Save a session first
|
||||
expiresAt := time.Now().Add(24 * time.Hour)
|
||||
session, err := store.SaveSession(ctx, 1, expiresAt, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete the session
|
||||
err = store.DeleteSession(ctx, session.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the session was deleted
|
||||
var count int
|
||||
count, err = store.Count(ctx, `SELECT COUNT(*) FROM sessions WHERE id = $1`, session.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, count)
|
||||
|
||||
truncateTestDataForSessions(t)
|
||||
}
|
||||
|
||||
func TestGetUserBySession(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Save a session first
|
||||
expiresAt := time.Now().Add(24 * time.Hour)
|
||||
session, err := store.SaveSession(ctx, 1, expiresAt, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the user by session
|
||||
user, err := store.GetUserBySession(ctx, session.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
assert.Equal(t, int32(1), user.ID)
|
||||
assert.Equal(t, "test", user.Username)
|
||||
assert.Equal(t, []uint8([]byte{0xab, 0xc1, 0x23}), user.Password)
|
||||
assert.Equal(t, "user", string(user.Role))
|
||||
|
||||
// Test for a non-existent session
|
||||
nonExistentSessionID := uuid.New()
|
||||
user, err = store.GetUserBySession(ctx, nonExistentSessionID)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, user)
|
||||
|
||||
truncateTestDataForSessions(t)
|
||||
}
|
||||
119
internal/db/psql/top_albums.go
Normal file
119
internal/db/psql/top_albums.go
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/gabehf/koito/internal/logger"
|
||||
"github.com/gabehf/koito/internal/models"
|
||||
"github.com/gabehf/koito/internal/repository"
|
||||
"github.com/gabehf/koito/internal/utils"
|
||||
)
|
||||
|
||||
func (d *Psql) GetTopAlbumsPaginated(ctx context.Context, opts db.GetItemsOpts) (*db.PaginatedResponse[*models.Album], error) {
|
||||
l := logger.FromContext(ctx)
|
||||
offset := (opts.Page - 1) * opts.Limit
|
||||
t1, t2, err := utils.DateRange(opts.Week, opts.Month, opts.Year)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if opts.Month == 0 && opts.Year == 0 {
|
||||
// use period, not date range
|
||||
t2 = time.Now()
|
||||
t1 = db.StartTimeFromPeriod(opts.Period)
|
||||
}
|
||||
if opts.Limit == 0 {
|
||||
opts.Limit = DefaultItemsPerPage
|
||||
}
|
||||
|
||||
var rgs []*models.Album
|
||||
var count int64
|
||||
|
||||
if opts.ArtistID != 0 {
|
||||
l.Debug().Msgf("Fetching top %d albums from artist id %d with period %s on page %d from range %v to %v",
|
||||
opts.Limit, opts.ArtistID, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
|
||||
|
||||
rows, err := d.q.GetTopReleasesFromArtist(ctx, repository.GetTopReleasesFromArtistParams{
|
||||
ArtistID: int32(opts.ArtistID),
|
||||
Limit: int32(opts.Limit),
|
||||
Offset: int32(offset),
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rgs = make([]*models.Album, len(rows))
|
||||
l.Debug().Msgf("Database responded with %d items", len(rows))
|
||||
for i, v := range rows {
|
||||
artists := make([]models.SimpleArtist, 0)
|
||||
err = json.Unmarshal(v.Artists, &artists)
|
||||
if err != nil {
|
||||
l.Err(err).Msgf("Error unmarshalling artists for release group with id %d", v.ID)
|
||||
artists = nil
|
||||
}
|
||||
rgs[i] = &models.Album{
|
||||
ID: v.ID,
|
||||
MbzID: v.MusicBrainzID,
|
||||
Title: v.Title,
|
||||
Image: v.Image,
|
||||
Artists: artists,
|
||||
VariousArtists: v.VariousArtists,
|
||||
ListenCount: v.ListenCount,
|
||||
}
|
||||
}
|
||||
count, err = d.q.CountReleasesFromArtist(ctx, int32(opts.ArtistID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
l.Debug().Msgf("Fetching top %d albums with period %s on page %d from range %v to %v",
|
||||
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
|
||||
rows, err := d.q.GetTopReleasesPaginated(ctx, repository.GetTopReleasesPaginatedParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
Limit: int32(opts.Limit),
|
||||
Offset: int32(offset),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rgs = make([]*models.Album, len(rows))
|
||||
l.Debug().Msgf("Database responded with %d items", len(rows))
|
||||
for i, row := range rows {
|
||||
artists := make([]models.SimpleArtist, 0)
|
||||
err = json.Unmarshal(row.Artists, &artists)
|
||||
if err != nil {
|
||||
l.Err(err).Msgf("Error unmarshalling artists for release group with id %d", row.ID)
|
||||
artists = nil
|
||||
}
|
||||
t := &models.Album{
|
||||
Title: row.Title,
|
||||
MbzID: row.MusicBrainzID,
|
||||
ID: row.ID,
|
||||
Image: row.Image,
|
||||
Artists: artists,
|
||||
VariousArtists: row.VariousArtists,
|
||||
ListenCount: row.ListenCount,
|
||||
}
|
||||
rgs[i] = t
|
||||
}
|
||||
count, err = d.q.CountTopReleases(ctx, repository.CountTopReleasesParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.Debug().Msgf("Database responded with %d albums out of a total %d", len(rows), count)
|
||||
}
|
||||
return &db.PaginatedResponse[*models.Album]{
|
||||
Items: rgs,
|
||||
TotalCount: count,
|
||||
ItemsPerPage: int32(opts.Limit),
|
||||
HasNextPage: int64(offset+len(rgs)) < count,
|
||||
CurrentPage: int32(opts.Page),
|
||||
}, nil
|
||||
}
|
||||
103
internal/db/psql/top_albums_test.go
Normal file
103
internal/db/psql/top_albums_test.go
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
package psql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetTopAlbumsPaginated(t *testing.T) {
|
||||
testDataForTopItems(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test valid
|
||||
resp, err := store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 4)
|
||||
assert.Equal(t, int64(4), resp.TotalCount)
|
||||
assert.Equal(t, "Release One", resp.Items[0].Title)
|
||||
assert.Equal(t, "Release Two", resp.Items[1].Title)
|
||||
assert.Equal(t, "Release Three", resp.Items[2].Title)
|
||||
assert.Equal(t, "Release Four", resp.Items[3].Title)
|
||||
|
||||
// Test pagination
|
||||
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 2, Period: db.PeriodAllTime})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 1)
|
||||
assert.Equal(t, "Release Two", resp.Items[0].Title)
|
||||
|
||||
// Test page out of range
|
||||
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 10, Period: db.PeriodAllTime})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, resp.Items)
|
||||
assert.False(t, resp.HasNextPage)
|
||||
|
||||
// Test invalid inputs
|
||||
_, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Limit: -1, Page: 0})
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: -1})
|
||||
assert.Error(t, err)
|
||||
|
||||
// Test specify period
|
||||
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodDay})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 0) // empty
|
||||
assert.Equal(t, int64(0), resp.TotalCount)
|
||||
// should default to PeriodDay
|
||||
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 0) // empty
|
||||
assert.Equal(t, int64(0), resp.TotalCount)
|
||||
|
||||
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodWeek})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 1)
|
||||
assert.Equal(t, int64(1), resp.TotalCount)
|
||||
assert.Equal(t, "Release Four", resp.Items[0].Title)
|
||||
|
||||
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodMonth})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 2)
|
||||
assert.Equal(t, int64(2), resp.TotalCount)
|
||||
assert.Equal(t, "Release Three", resp.Items[0].Title)
|
||||
assert.Equal(t, "Release Four", resp.Items[1].Title)
|
||||
|
||||
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodYear})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 3)
|
||||
assert.Equal(t, int64(3), resp.TotalCount)
|
||||
assert.Equal(t, "Release Two", resp.Items[0].Title)
|
||||
assert.Equal(t, "Release Three", resp.Items[1].Title)
|
||||
assert.Equal(t, "Release Four", resp.Items[2].Title)
|
||||
|
||||
// test specific artist
|
||||
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodYear, ArtistID: 2})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 1)
|
||||
assert.Equal(t, int64(1), resp.TotalCount)
|
||||
assert.Equal(t, "Release Two", resp.Items[0].Title)
|
||||
|
||||
// Test specify dates
|
||||
|
||||
testDataAbsoluteListenTimes(t)
|
||||
|
||||
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Year: 2023})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 1)
|
||||
assert.Equal(t, int64(1), resp.TotalCount)
|
||||
assert.Equal(t, "Release One", resp.Items[0].Title)
|
||||
|
||||
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Month: 6, Year: 2024})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 1)
|
||||
assert.Equal(t, int64(1), resp.TotalCount)
|
||||
assert.Equal(t, "Release Two", resp.Items[0].Title)
|
||||
|
||||
// invalid, year required with month
|
||||
_, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Month: 10})
|
||||
require.Error(t, err)
|
||||
}
|
||||
67
internal/db/psql/top_artists.go
Normal file
67
internal/db/psql/top_artists.go
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/gabehf/koito/internal/logger"
|
||||
"github.com/gabehf/koito/internal/models"
|
||||
"github.com/gabehf/koito/internal/repository"
|
||||
"github.com/gabehf/koito/internal/utils"
|
||||
)
|
||||
|
||||
func (d *Psql) GetTopArtistsPaginated(ctx context.Context, opts db.GetItemsOpts) (*db.PaginatedResponse[*models.Artist], error) {
|
||||
l := logger.FromContext(ctx)
|
||||
offset := (opts.Page - 1) * opts.Limit
|
||||
t1, t2, err := utils.DateRange(opts.Week, opts.Month, opts.Year)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if opts.Month == 0 && opts.Year == 0 {
|
||||
// use period, not date range
|
||||
t2 = time.Now()
|
||||
t1 = db.StartTimeFromPeriod(opts.Period)
|
||||
}
|
||||
if opts.Limit == 0 {
|
||||
opts.Limit = DefaultItemsPerPage
|
||||
}
|
||||
l.Debug().Msgf("Fetching top %d artists with period %s on page %d from range %v to %v",
|
||||
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
|
||||
rows, err := d.q.GetTopArtistsPaginated(ctx, repository.GetTopArtistsPaginatedParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
Limit: int32(opts.Limit),
|
||||
Offset: int32(offset),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rgs := make([]*models.Artist, len(rows))
|
||||
for i, row := range rows {
|
||||
t := &models.Artist{
|
||||
Name: row.Name,
|
||||
MbzID: row.MusicBrainzID,
|
||||
ID: row.ID,
|
||||
Image: row.Image,
|
||||
ListenCount: row.ListenCount,
|
||||
}
|
||||
rgs[i] = t
|
||||
}
|
||||
count, err := d.q.CountTopArtists(ctx, repository.CountTopArtistsParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.Debug().Msgf("Database responded with %d artists out of a total %d", len(rows), count)
|
||||
|
||||
return &db.PaginatedResponse[*models.Artist]{
|
||||
Items: rgs,
|
||||
TotalCount: count,
|
||||
ItemsPerPage: int32(opts.Limit),
|
||||
HasNextPage: int64(offset+len(rgs)) < count,
|
||||
CurrentPage: int32(opts.Page),
|
||||
}, nil
|
||||
}
|
||||
96
internal/db/psql/top_artists_test.go
Normal file
96
internal/db/psql/top_artists_test.go
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
package psql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetTopArtistsPaginated(t *testing.T) {
|
||||
testDataForTopItems(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test valid
|
||||
resp, err := store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 4)
|
||||
assert.Equal(t, int64(4), resp.TotalCount)
|
||||
assert.Equal(t, "Artist One", resp.Items[0].Name)
|
||||
assert.Equal(t, "Artist Two", resp.Items[1].Name)
|
||||
assert.Equal(t, "Artist Three", resp.Items[2].Name)
|
||||
assert.Equal(t, "Artist Four", resp.Items[3].Name)
|
||||
|
||||
// Test pagination
|
||||
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 2, Period: db.PeriodAllTime})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 1)
|
||||
assert.Equal(t, "Artist Two", resp.Items[0].Name)
|
||||
|
||||
// Test page out of range
|
||||
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 10, Period: db.PeriodAllTime})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, resp.Items)
|
||||
assert.False(t, resp.HasNextPage)
|
||||
|
||||
// Test invalid inputs
|
||||
_, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Limit: -1, Page: 0})
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: -1})
|
||||
assert.Error(t, err)
|
||||
|
||||
// Test specify period
|
||||
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodDay})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 0) // empty
|
||||
assert.Equal(t, int64(0), resp.TotalCount)
|
||||
// should default to PeriodDay
|
||||
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 0) // empty
|
||||
assert.Equal(t, int64(0), resp.TotalCount)
|
||||
|
||||
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodWeek})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 1)
|
||||
assert.Equal(t, int64(1), resp.TotalCount)
|
||||
assert.Equal(t, "Artist Four", resp.Items[0].Name)
|
||||
|
||||
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodMonth})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 2)
|
||||
assert.Equal(t, int64(2), resp.TotalCount)
|
||||
assert.Equal(t, "Artist Three", resp.Items[0].Name)
|
||||
assert.Equal(t, "Artist Four", resp.Items[1].Name)
|
||||
|
||||
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodYear})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 3)
|
||||
assert.Equal(t, int64(3), resp.TotalCount)
|
||||
assert.Equal(t, "Artist Two", resp.Items[0].Name)
|
||||
assert.Equal(t, "Artist Three", resp.Items[1].Name)
|
||||
assert.Equal(t, "Artist Four", resp.Items[2].Name)
|
||||
|
||||
// Test specify dates
|
||||
|
||||
testDataAbsoluteListenTimes(t)
|
||||
|
||||
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Year: 2023})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 1)
|
||||
assert.Equal(t, int64(1), resp.TotalCount)
|
||||
assert.Equal(t, "Artist One", resp.Items[0].Name)
|
||||
|
||||
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Month: 6, Year: 2024})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 1)
|
||||
assert.Equal(t, int64(1), resp.TotalCount)
|
||||
assert.Equal(t, "Artist Two", resp.Items[0].Name)
|
||||
|
||||
// invalid, year required with month
|
||||
_, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Month: 10})
|
||||
require.Error(t, err)
|
||||
}
|
||||
160
internal/db/psql/top_tracks.go
Normal file
160
internal/db/psql/top_tracks.go
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/gabehf/koito/internal/logger"
|
||||
"github.com/gabehf/koito/internal/models"
|
||||
"github.com/gabehf/koito/internal/repository"
|
||||
"github.com/gabehf/koito/internal/utils"
|
||||
)
|
||||
|
||||
func (d *Psql) GetTopTracksPaginated(ctx context.Context, opts db.GetItemsOpts) (*db.PaginatedResponse[*models.Track], error) {
|
||||
l := logger.FromContext(ctx)
|
||||
offset := (opts.Page - 1) * opts.Limit
|
||||
t1, t2, err := utils.DateRange(opts.Week, opts.Month, opts.Year)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if opts.Month == 0 && opts.Year == 0 {
|
||||
// use period, not date range
|
||||
t2 = time.Now()
|
||||
t1 = db.StartTimeFromPeriod(opts.Period)
|
||||
}
|
||||
if opts.Limit == 0 {
|
||||
opts.Limit = DefaultItemsPerPage
|
||||
}
|
||||
var tracks []*models.Track
|
||||
var count int64
|
||||
if opts.AlbumID > 0 {
|
||||
l.Debug().Msgf("Fetching top %d tracks with period %s on page %d from range %v to %v",
|
||||
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
|
||||
rows, err := d.q.GetTopTracksInReleasePaginated(ctx, repository.GetTopTracksInReleasePaginatedParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
Limit: int32(opts.Limit),
|
||||
Offset: int32(offset),
|
||||
ReleaseID: int32(opts.AlbumID),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tracks = make([]*models.Track, len(rows))
|
||||
for i, row := range rows {
|
||||
artists := make([]models.SimpleArtist, 0)
|
||||
err = json.Unmarshal(row.Artists, &artists)
|
||||
if err != nil {
|
||||
l.Err(err).Msgf("Error unmarshalling artists for track with id %d", row.ID)
|
||||
artists = nil
|
||||
}
|
||||
t := &models.Track{
|
||||
Title: row.Title,
|
||||
MbzID: row.MusicBrainzID,
|
||||
ID: row.ID,
|
||||
ListenCount: row.ListenCount,
|
||||
Image: row.Image,
|
||||
AlbumID: row.ReleaseID,
|
||||
Artists: artists,
|
||||
}
|
||||
tracks[i] = t
|
||||
}
|
||||
count, err = d.q.CountTopTracksByRelease(ctx, repository.CountTopTracksByReleaseParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
ReleaseID: int32(opts.AlbumID),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if opts.ArtistID > 0 {
|
||||
l.Debug().Msgf("Fetching top %d tracks with period %s on page %d from range %v to %v",
|
||||
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
|
||||
rows, err := d.q.GetTopTracksByArtistPaginated(ctx, repository.GetTopTracksByArtistPaginatedParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
Limit: int32(opts.Limit),
|
||||
Offset: int32(offset),
|
||||
ArtistID: int32(opts.ArtistID),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tracks = make([]*models.Track, len(rows))
|
||||
for i, row := range rows {
|
||||
artists := make([]models.SimpleArtist, 0)
|
||||
err = json.Unmarshal(row.Artists, &artists)
|
||||
if err != nil {
|
||||
l.Err(err).Msgf("Error unmarshalling artists for track with id %d", row.ID)
|
||||
artists = nil
|
||||
}
|
||||
t := &models.Track{
|
||||
Title: row.Title,
|
||||
MbzID: row.MusicBrainzID,
|
||||
ID: row.ID,
|
||||
Image: row.Image,
|
||||
ListenCount: row.ListenCount,
|
||||
AlbumID: row.ReleaseID,
|
||||
Artists: artists,
|
||||
}
|
||||
tracks[i] = t
|
||||
}
|
||||
count, err = d.q.CountTopTracksByArtist(ctx, repository.CountTopTracksByArtistParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
ArtistID: int32(opts.ArtistID),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
l.Debug().Msgf("Fetching top %d tracks with period %s on page %d from range %v to %v",
|
||||
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
|
||||
rows, err := d.q.GetTopTracksPaginated(ctx, repository.GetTopTracksPaginatedParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
Limit: int32(opts.Limit),
|
||||
Offset: int32(offset),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tracks = make([]*models.Track, len(rows))
|
||||
for i, row := range rows {
|
||||
artists := make([]models.SimpleArtist, 0)
|
||||
err = json.Unmarshal(row.Artists, &artists)
|
||||
if err != nil {
|
||||
l.Err(err).Msgf("Error unmarshalling artists for track with id %d", row.ID)
|
||||
artists = nil
|
||||
}
|
||||
t := &models.Track{
|
||||
Title: row.Title,
|
||||
MbzID: row.MusicBrainzID,
|
||||
ID: row.ID,
|
||||
Image: row.Image,
|
||||
ListenCount: row.ListenCount,
|
||||
AlbumID: row.ReleaseID,
|
||||
Artists: artists,
|
||||
}
|
||||
tracks[i] = t
|
||||
}
|
||||
count, err = d.q.CountTopTracks(ctx, repository.CountTopTracksParams{
|
||||
ListenedAt: t1,
|
||||
ListenedAt_2: t2,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.Debug().Msgf("Database responded with %d tracks out of a total %d", len(rows), count)
|
||||
}
|
||||
|
||||
return &db.PaginatedResponse[*models.Track]{
|
||||
Items: tracks,
|
||||
TotalCount: count,
|
||||
ItemsPerPage: int32(opts.Limit),
|
||||
HasNextPage: int64(offset+len(tracks)) < count,
|
||||
CurrentPage: int32(opts.Page),
|
||||
}, nil
|
||||
}
|
||||
118
internal/db/psql/top_tracks_test.go
Normal file
118
internal/db/psql/top_tracks_test.go
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
package psql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetTopTracksPaginated(t *testing.T) {
|
||||
testDataForTopItems(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test valid
|
||||
resp, err := store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 4)
|
||||
assert.Equal(t, int64(4), resp.TotalCount)
|
||||
assert.Equal(t, "Track One", resp.Items[0].Title)
|
||||
assert.Equal(t, "Track Two", resp.Items[1].Title)
|
||||
assert.Equal(t, "Track Three", resp.Items[2].Title)
|
||||
assert.Equal(t, "Track Four", resp.Items[3].Title)
|
||||
// ensure artists are included
|
||||
require.Len(t, resp.Items[0].Artists, 1)
|
||||
assert.Equal(t, "Artist One", resp.Items[0].Artists[0].Name)
|
||||
|
||||
// Test pagination
|
||||
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 2, Period: db.PeriodAllTime})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 1)
|
||||
assert.Equal(t, "Track Two", resp.Items[0].Title)
|
||||
|
||||
// Test page out of range
|
||||
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 10, Period: db.PeriodAllTime})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, resp.Items)
|
||||
assert.False(t, resp.HasNextPage)
|
||||
|
||||
// Test invalid inputs
|
||||
_, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Limit: -1, Page: 0})
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: -1})
|
||||
assert.Error(t, err)
|
||||
|
||||
// Test specify period
|
||||
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodDay})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 0) // empty
|
||||
assert.Equal(t, int64(0), resp.TotalCount)
|
||||
// should default to PeriodDay
|
||||
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 0) // empty
|
||||
assert.Equal(t, int64(0), resp.TotalCount)
|
||||
|
||||
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodWeek})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 1)
|
||||
assert.Equal(t, int64(1), resp.TotalCount)
|
||||
assert.Equal(t, "Track Four", resp.Items[0].Title)
|
||||
|
||||
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodMonth})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 2)
|
||||
assert.Equal(t, int64(2), resp.TotalCount)
|
||||
assert.Equal(t, "Track Three", resp.Items[0].Title)
|
||||
assert.Equal(t, "Track Four", resp.Items[1].Title)
|
||||
|
||||
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodYear})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 3)
|
||||
assert.Equal(t, int64(3), resp.TotalCount)
|
||||
assert.Equal(t, "Track Two", resp.Items[0].Title)
|
||||
assert.Equal(t, "Track Three", resp.Items[1].Title)
|
||||
assert.Equal(t, "Track Four", resp.Items[2].Title)
|
||||
|
||||
// Test filter by artists and releases
|
||||
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, ArtistID: 1})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 1)
|
||||
assert.Equal(t, int64(1), resp.TotalCount)
|
||||
assert.Equal(t, "Track One", resp.Items[0].Title)
|
||||
|
||||
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, AlbumID: 2})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 1)
|
||||
assert.Equal(t, int64(1), resp.TotalCount)
|
||||
assert.Equal(t, "Track Two", resp.Items[0].Title)
|
||||
// when both artistID and albumID are specified, artist id is ignored
|
||||
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, AlbumID: 2, ArtistID: 1})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 1)
|
||||
assert.Equal(t, int64(1), resp.TotalCount)
|
||||
assert.Equal(t, "Track Two", resp.Items[0].Title)
|
||||
|
||||
// Test specify dates
|
||||
|
||||
testDataAbsoluteListenTimes(t)
|
||||
|
||||
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Year: 2023})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 1)
|
||||
assert.Equal(t, int64(1), resp.TotalCount)
|
||||
assert.Equal(t, "Track One", resp.Items[0].Title)
|
||||
|
||||
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Month: 6, Year: 2024})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Items, 1)
|
||||
assert.Equal(t, int64(1), resp.TotalCount)
|
||||
assert.Equal(t, "Track Two", resp.Items[0].Title)
|
||||
|
||||
// invalid, year required with month
|
||||
_, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Month: 10})
|
||||
require.Error(t, err)
|
||||
}
|
||||
298
internal/db/psql/track.go
Normal file
298
internal/db/psql/track.go
Normal file
|
|
@ -0,0 +1,298 @@
|
|||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/gabehf/koito/internal/logger"
|
||||
"github.com/gabehf/koito/internal/models"
|
||||
"github.com/gabehf/koito/internal/repository"
|
||||
"github.com/gabehf/koito/internal/utils"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
func (d *Psql) GetTrack(ctx context.Context, opts db.GetTrackOpts) (*models.Track, error) {
|
||||
l := logger.FromContext(ctx)
|
||||
var track models.Track
|
||||
|
||||
if opts.ID != 0 {
|
||||
l.Debug().Msgf("Fetching track from DB with id %d", opts.ID)
|
||||
t, err := d.q.GetTrack(ctx, opts.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
track = models.Track{
|
||||
ID: t.ID,
|
||||
MbzID: t.MusicBrainzID,
|
||||
Title: t.Title,
|
||||
AlbumID: t.ReleaseID,
|
||||
Image: t.Image,
|
||||
Duration: t.Duration,
|
||||
}
|
||||
} else if opts.MusicBrainzID != uuid.Nil {
|
||||
l.Debug().Msgf("Fetching track from DB with MusicBrainz ID %s", opts.MusicBrainzID)
|
||||
t, err := d.q.GetTrackByMbzID(ctx, &opts.MusicBrainzID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
track = models.Track{
|
||||
ID: t.ID,
|
||||
MbzID: t.MusicBrainzID,
|
||||
Title: t.Title,
|
||||
AlbumID: t.ReleaseID,
|
||||
Duration: t.Duration,
|
||||
}
|
||||
} else if len(opts.ArtistIDs) > 0 {
|
||||
l.Debug().Msgf("Fetching track from DB with title '%s' and artist id(s) '%v'", opts.Title, opts.ArtistIDs)
|
||||
t, err := d.q.GetTrackByTitleAndArtists(ctx, repository.GetTrackByTitleAndArtistsParams{
|
||||
Title: opts.Title,
|
||||
Column2: opts.ArtistIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
track = models.Track{
|
||||
ID: t.ID,
|
||||
MbzID: t.MusicBrainzID,
|
||||
Title: t.Title,
|
||||
AlbumID: t.ReleaseID,
|
||||
Duration: t.Duration,
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("insufficient information to get track")
|
||||
}
|
||||
|
||||
count, err := d.q.CountListensFromTrack(ctx, repository.CountListensFromTrackParams{
|
||||
ListenedAt: time.Unix(0, 0),
|
||||
ListenedAt_2: time.Now(),
|
||||
TrackID: track.ID,
|
||||
})
|
||||
if err != nil {
|
||||
l.Err(err).Msgf("Failed to get listen count for track with id %d", track.ID)
|
||||
}
|
||||
|
||||
track.ListenCount = count
|
||||
|
||||
return &track, nil
|
||||
}
|
||||
|
||||
func (d *Psql) SaveTrack(ctx context.Context, opts db.SaveTrackOpts) (*models.Track, error) {
|
||||
// create track in DB
|
||||
l := logger.FromContext(ctx)
|
||||
var insertMbzID *uuid.UUID
|
||||
if opts.RecordingMbzID != uuid.Nil {
|
||||
insertMbzID = &opts.RecordingMbzID
|
||||
}
|
||||
if len(opts.ArtistIDs) < 1 {
|
||||
return nil, errors.New("required parameter 'ArtistIDs' missing")
|
||||
}
|
||||
for _, aid := range opts.ArtistIDs {
|
||||
if aid == 0 {
|
||||
return nil, errors.New("none of 'ArtistIDs' may be 0")
|
||||
}
|
||||
}
|
||||
if opts.AlbumID == 0 {
|
||||
return nil, errors.New("required parameter 'AlbumID' missing")
|
||||
}
|
||||
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to begin transaction")
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
qtx := d.q.WithTx(tx)
|
||||
l.Debug().Msgf("Inserting new track '%s' into DB", opts.Title)
|
||||
trackRow, err := qtx.InsertTrack(ctx, repository.InsertTrackParams{
|
||||
MusicBrainzID: insertMbzID,
|
||||
ReleaseID: opts.AlbumID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// insert associated artists
|
||||
for _, aid := range opts.ArtistIDs {
|
||||
err = qtx.AssociateArtistToTrack(ctx, repository.AssociateArtistToTrackParams{
|
||||
ArtistID: aid,
|
||||
TrackID: trackRow.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// insert primary alias
|
||||
err = qtx.InsertTrackAlias(ctx, repository.InsertTrackAliasParams{
|
||||
TrackID: trackRow.ID,
|
||||
Alias: opts.Title,
|
||||
Source: "Canonical",
|
||||
IsPrimary: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = tx.Commit(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &models.Track{
|
||||
ID: trackRow.ID,
|
||||
MbzID: insertMbzID,
|
||||
Title: opts.Title,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Psql) UpdateTrack(ctx context.Context, opts db.UpdateTrackOpts) error {
|
||||
l := logger.FromContext(ctx)
|
||||
if opts.ID == 0 {
|
||||
return errors.New("track id not specified")
|
||||
}
|
||||
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to begin transaction")
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
qtx := d.q.WithTx(tx)
|
||||
if opts.MusicBrainzID != uuid.Nil {
|
||||
l.Debug().Msgf("Updating MusicBrainz ID for track %d", opts.ID)
|
||||
err := qtx.UpdateTrackMbzID(ctx, repository.UpdateTrackMbzIDParams{
|
||||
ID: opts.ID,
|
||||
MusicBrainzID: &opts.MusicBrainzID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if opts.Duration != 0 {
|
||||
l.Debug().Msgf("Updating duration for track %d", opts.ID)
|
||||
err := qtx.UpdateTrackDuration(ctx, repository.UpdateTrackDurationParams{
|
||||
ID: opts.ID,
|
||||
Duration: opts.Duration,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (d *Psql) SaveTrackAliases(ctx context.Context, id int32, aliases []string, source string) error {
|
||||
l := logger.FromContext(ctx)
|
||||
if id == 0 {
|
||||
return errors.New("track id not specified")
|
||||
}
|
||||
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to begin transaction")
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
qtx := d.q.WithTx(tx)
|
||||
existing, err := qtx.GetAllTrackAliases(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, v := range existing {
|
||||
aliases = append(aliases, v.Alias)
|
||||
}
|
||||
utils.Unique(&aliases)
|
||||
for _, alias := range aliases {
|
||||
if strings.TrimSpace(alias) == "" {
|
||||
return errors.New("aliases cannot be blank")
|
||||
}
|
||||
err = qtx.InsertTrackAlias(ctx, repository.InsertTrackAliasParams{
|
||||
Alias: strings.TrimSpace(alias),
|
||||
TrackID: id,
|
||||
Source: source,
|
||||
IsPrimary: false,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (d *Psql) DeleteTrack(ctx context.Context, id int32) error {
|
||||
return d.q.DeleteTrack(ctx, id)
|
||||
}
|
||||
|
||||
func (d *Psql) DeleteTrackAlias(ctx context.Context, id int32, alias string) error {
|
||||
return d.q.DeleteTrackAlias(ctx, repository.DeleteTrackAliasParams{
|
||||
TrackID: id,
|
||||
Alias: alias,
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Psql) GetAllTrackAliases(ctx context.Context, id int32) ([]models.Alias, error) {
|
||||
rows, err := d.q.GetAllTrackAliases(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aliases := make([]models.Alias, len(rows))
|
||||
for i, row := range rows {
|
||||
aliases[i] = models.Alias{
|
||||
ID: id,
|
||||
Alias: row.Alias,
|
||||
Source: row.Source,
|
||||
Primary: row.IsPrimary,
|
||||
}
|
||||
}
|
||||
return aliases, nil
|
||||
}
|
||||
|
||||
func (d *Psql) SetPrimaryTrackAlias(ctx context.Context, id int32, alias string) error {
|
||||
l := logger.FromContext(ctx)
|
||||
if id == 0 {
|
||||
return errors.New("artist id not specified")
|
||||
}
|
||||
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to begin transaction")
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
qtx := d.q.WithTx(tx)
|
||||
// get all aliases
|
||||
aliases, err := qtx.GetAllTrackAliases(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
primary := ""
|
||||
exists := false
|
||||
for _, v := range aliases {
|
||||
if v.Alias == alias {
|
||||
exists = true
|
||||
}
|
||||
if v.IsPrimary {
|
||||
primary = v.Alias
|
||||
}
|
||||
}
|
||||
if primary == alias {
|
||||
// no-op rename
|
||||
return nil
|
||||
}
|
||||
if !exists {
|
||||
return errors.New("alias does not exist")
|
||||
}
|
||||
err = qtx.SetTrackAliasPrimaryStatus(ctx, repository.SetTrackAliasPrimaryStatusParams{
|
||||
TrackID: id,
|
||||
Alias: alias,
|
||||
IsPrimary: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = qtx.SetTrackAliasPrimaryStatus(ctx, repository.SetTrackAliasPrimaryStatusParams{
|
||||
TrackID: id,
|
||||
Alias: primary,
|
||||
IsPrimary: false,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
213
internal/db/psql/track_test.go
Normal file
213
internal/db/psql/track_test.go
Normal file
|
|
@ -0,0 +1,213 @@
|
|||
package psql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testDataForTracks(t *testing.T) {
|
||||
truncateTestData(t)
|
||||
|
||||
// Insert artists
|
||||
err := store.Exec(context.Background(),
|
||||
`INSERT INTO artists (musicbrainz_id)
|
||||
VALUES ('00000000-0000-0000-0000-000000000001'),
|
||||
('00000000-0000-0000-0000-000000000002')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert artist aliases
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
|
||||
VALUES (1, 'Artist One', 'Testing', true),
|
||||
(2, 'Artist Two', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert release groups
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO releases (musicbrainz_id)
|
||||
VALUES ('00000000-0000-0000-0000-000000000011'),
|
||||
('00000000-0000-0000-0000-000000000022')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert release aliases
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
|
||||
VALUES (1, 'Release Group One', 'Testing', true),
|
||||
(2, 'Release Group Two', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert tracks
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO tracks (musicbrainz_id, release_id)
|
||||
VALUES ('11111111-1111-1111-1111-111111111111', 1),
|
||||
('22222222-2222-2222-2222-222222222222', 2)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert track aliases
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO track_aliases (track_id, alias, source, is_primary)
|
||||
VALUES (1, 'Track One', 'Testing', true),
|
||||
(2, 'Track Two', 'Testing', true)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Associate tracks with artists
|
||||
err = store.Exec(context.Background(),
|
||||
`INSERT INTO artist_tracks (artist_id, track_id)
|
||||
VALUES (1, 1), (2, 2)`)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestGetTrack(t *testing.T) {
|
||||
testDataForTracks(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test GetTrack by ID
|
||||
track, err := store.GetTrack(ctx, db.GetTrackOpts{ID: 1})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int32(1), track.ID)
|
||||
assert.Equal(t, "Track One", track.Title)
|
||||
assert.Equal(t, uuid.MustParse("11111111-1111-1111-1111-111111111111"), *track.MbzID)
|
||||
|
||||
// Test GetTrack by MusicBrainzID
|
||||
track, err = store.GetTrack(ctx, db.GetTrackOpts{MusicBrainzID: uuid.MustParse("22222222-2222-2222-2222-222222222222")})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int32(2), track.ID)
|
||||
assert.Equal(t, "Track Two", track.Title)
|
||||
|
||||
// Test GetTrack by Title and ArtistIDs
|
||||
track, err = store.GetTrack(ctx, db.GetTrackOpts{
|
||||
Title: "Track One",
|
||||
ArtistIDs: []int32{1},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int32(1), track.ID)
|
||||
assert.Equal(t, "Track One", track.Title)
|
||||
|
||||
// Test GetTrack with insufficient information
|
||||
_, err = store.GetTrack(ctx, db.GetTrackOpts{})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
func TestSaveTrack(t *testing.T) {
|
||||
testDataForTracks(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test SaveTrack with valid inputs
|
||||
track, err := store.SaveTrack(ctx, db.SaveTrackOpts{
|
||||
Title: "New Track",
|
||||
ArtistIDs: []int32{1},
|
||||
RecordingMbzID: uuid.MustParse("33333333-3333-3333-3333-333333333333"),
|
||||
AlbumID: 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "New Track", track.Title)
|
||||
assert.Equal(t, uuid.MustParse("33333333-3333-3333-3333-333333333333"), *track.MbzID)
|
||||
|
||||
// Verify artist associations exist
|
||||
exists, err := store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM artist_tracks
|
||||
WHERE artist_id = $1 AND track_id = $2
|
||||
)`, 1, track.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "expected artist association to exist")
|
||||
|
||||
// Verify alias exists
|
||||
exists, err = store.RowExists(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM track_aliases
|
||||
WHERE track_id = $1 AND is_primary = true
|
||||
)`, track.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "expected primary alias to exist")
|
||||
|
||||
// Test SaveTrack with missing ArtistIDs
|
||||
_, err = store.SaveTrack(ctx, db.SaveTrackOpts{
|
||||
Title: "Invalid Track",
|
||||
ArtistIDs: []int32{},
|
||||
RecordingMbzID: uuid.MustParse("44444444-4444-4444-4444-444444444444"),
|
||||
})
|
||||
assert.Error(t, err)
|
||||
|
||||
// Test SaveTrack with invalid ArtistIDs
|
||||
_, err = store.SaveTrack(ctx, db.SaveTrackOpts{
|
||||
Title: "Invalid Track",
|
||||
ArtistIDs: []int32{0},
|
||||
RecordingMbzID: uuid.MustParse("55555555-5555-5555-5555-555555555555"),
|
||||
})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUpdateTrack(t *testing.T) {
|
||||
testDataForTracks(t)
|
||||
ctx := context.Background()
|
||||
|
||||
newMbzID := uuid.MustParse("66666666-6666-6666-6666-666666666666")
|
||||
newDuration := 100
|
||||
err := store.UpdateTrack(ctx, db.UpdateTrackOpts{
|
||||
ID: 1,
|
||||
MusicBrainzID: newMbzID,
|
||||
Duration: int32(newDuration),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the update
|
||||
track, err := store.GetTrack(ctx, db.GetTrackOpts{ID: 1})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, newMbzID, *track.MbzID)
|
||||
require.EqualValues(t, newDuration, track.Duration)
|
||||
|
||||
// Test UpdateTrack with missing ID
|
||||
err = store.UpdateTrack(ctx, db.UpdateTrackOpts{
|
||||
ID: 0,
|
||||
MusicBrainzID: newMbzID,
|
||||
Duration: int32(newDuration),
|
||||
})
|
||||
assert.Error(t, err)
|
||||
|
||||
// Test UpdateTrack with nil MusicBrainz ID
|
||||
err = store.UpdateTrack(ctx, db.UpdateTrackOpts{
|
||||
ID: 1,
|
||||
MusicBrainzID: uuid.Nil,
|
||||
Duration: int32(newDuration),
|
||||
})
|
||||
assert.NoError(t, err) // No update should occur
|
||||
}
|
||||
|
||||
func TestTrackAliases(t *testing.T) {
|
||||
testDataForTracks(t)
|
||||
ctx := context.Background()
|
||||
|
||||
err := store.SaveTrackAliases(ctx, 1, []string{"Alias One", "Alias Two"}, "Testing")
|
||||
require.NoError(t, err)
|
||||
aliases, err := store.GetAllTrackAliases(ctx, 1)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, aliases, 3)
|
||||
|
||||
err = store.SetPrimaryTrackAlias(ctx, 1, "Alias One")
|
||||
require.NoError(t, err)
|
||||
track, err := store.GetTrack(ctx, db.GetTrackOpts{ID: 1})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Alias One", track.Title)
|
||||
|
||||
err = store.SetPrimaryTrackAlias(ctx, 1, "Fake Alias")
|
||||
require.Error(t, err)
|
||||
|
||||
store.SetPrimaryTrackAlias(ctx, 1, "Track One")
|
||||
}
|
||||
|
||||
func TestDeleteTrack(t *testing.T) {
|
||||
testDataForTracks(t)
|
||||
ctx := context.Background()
|
||||
|
||||
err := store.DeleteTrack(ctx, 2)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.Count(ctx, `SELECT * FROM tracks WHERE id = 2`)
|
||||
require.ErrorIs(t, err, pgx.ErrNoRows) // no rows error
|
||||
}
|
||||
219
internal/db/psql/user.go
Normal file
219
internal/db/psql/user.go
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/gabehf/koito/internal/logger"
|
||||
"github.com/gabehf/koito/internal/models"
|
||||
"github.com/gabehf/koito/internal/repository"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Returns nil, nil when no database entries are found
|
||||
func (d *Psql) GetUserByUsername(ctx context.Context, username string) (*models.User, error) {
|
||||
row, err := d.q.GetUserByUsername(ctx, strings.ToLower(username))
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &models.User{
|
||||
ID: row.ID,
|
||||
Username: row.Username,
|
||||
Password: row.Password,
|
||||
Role: models.UserRole(row.Role),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Returns nil, nil when no database entries are found
|
||||
func (d *Psql) GetUserByApiKey(ctx context.Context, key string) (*models.User, error) {
|
||||
row, err := d.q.GetUserByApiKey(ctx, key)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &models.User{
|
||||
ID: row.ID,
|
||||
Username: row.Username,
|
||||
Password: row.Password,
|
||||
Role: models.UserRole(row.Role),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Psql) SaveUser(ctx context.Context, opts db.SaveUserOpts) (*models.User, error) {
|
||||
l := logger.FromContext(ctx)
|
||||
err := ValidateUsername(opts.Username)
|
||||
if err != nil {
|
||||
l.Debug().AnErr("validator_notice", err).Msgf("Username failed validation: %s", opts.Username)
|
||||
return nil, err
|
||||
}
|
||||
pw, err := ValidateAndNormalizePassword(opts.Password)
|
||||
if err != nil {
|
||||
l.Debug().AnErr("validator_notice", err).Msgf("Password failed validation")
|
||||
return nil, err
|
||||
}
|
||||
if opts.Role == "" {
|
||||
opts.Role = models.UserRoleUser
|
||||
}
|
||||
hashPw, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to generate hashed password")
|
||||
return nil, err
|
||||
}
|
||||
u, err := d.q.InsertUser(ctx, repository.InsertUserParams{
|
||||
Username: strings.ToLower(opts.Username),
|
||||
Password: hashPw,
|
||||
Role: repository.Role(opts.Role),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &models.User{
|
||||
ID: u.ID,
|
||||
Username: u.Username,
|
||||
Role: models.UserRole(u.Role),
|
||||
}, nil
|
||||
}
|
||||
func (d *Psql) SaveApiKey(ctx context.Context, opts db.SaveApiKeyOpts) (*models.ApiKey, error) {
|
||||
row, err := d.q.InsertApiKey(ctx, repository.InsertApiKeyParams{
|
||||
Key: opts.Key,
|
||||
Label: opts.Label,
|
||||
UserID: opts.UserID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &models.ApiKey{
|
||||
ID: row.ID,
|
||||
UserID: row.UserID,
|
||||
Key: row.Key,
|
||||
Label: row.Label,
|
||||
CreatedAt: row.CreatedAt.Time,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Psql) UpdateUser(ctx context.Context, opts db.UpdateUserOpts) error {
|
||||
l := logger.FromContext(ctx)
|
||||
if opts.ID == 0 {
|
||||
return errors.New("user id is required")
|
||||
}
|
||||
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to begin transaction")
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
qtx := d.q.WithTx(tx)
|
||||
if opts.Username != "" {
|
||||
err := ValidateUsername(opts.Username)
|
||||
if err != nil {
|
||||
l.Debug().AnErr("validator_notice", err).Msgf("Username failed validation: %s", opts.Username)
|
||||
return err
|
||||
}
|
||||
err = qtx.UpdateUserUsername(ctx, repository.UpdateUserUsernameParams{
|
||||
ID: opts.ID,
|
||||
Username: opts.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if opts.Password != "" {
|
||||
pw, err := ValidateAndNormalizePassword(opts.Password)
|
||||
if err != nil {
|
||||
l.Debug().AnErr("validator_notice", err).Msgf("Password failed validation")
|
||||
return err
|
||||
}
|
||||
hashPw, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to generate hashed password")
|
||||
return err
|
||||
}
|
||||
err = qtx.UpdateUserPassword(ctx, repository.UpdateUserPasswordParams{
|
||||
ID: opts.ID,
|
||||
Password: hashPw,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (d *Psql) GetApiKeysByUserID(ctx context.Context, id int32) ([]models.ApiKey, error) {
|
||||
rows, err := d.q.GetAllApiKeysByUserID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys := make([]models.ApiKey, len(rows))
|
||||
for i, row := range rows {
|
||||
keys[i] = models.ApiKey{
|
||||
ID: row.ID,
|
||||
Key: row.Key,
|
||||
Label: row.Label,
|
||||
UserID: row.UserID,
|
||||
}
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (d *Psql) UpdateApiKeyLabel(ctx context.Context, opts db.UpdateApiKeyLabelOpts) error {
|
||||
return d.q.UpdateApiKeyLabel(ctx, repository.UpdateApiKeyLabelParams{
|
||||
ID: opts.ID,
|
||||
Label: opts.Label,
|
||||
UserID: opts.UserID,
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Psql) DeleteApiKey(ctx context.Context, id int32) error {
|
||||
return d.q.DeleteApiKey(ctx, id)
|
||||
}
|
||||
|
||||
func (d *Psql) CountUsers(ctx context.Context) (int64, error) {
|
||||
return d.q.CountUsers(ctx)
|
||||
}
|
||||
|
||||
const (
|
||||
maxUsernameLength = 32
|
||||
minUsernameLength = 1
|
||||
maxPasswordLength = 128
|
||||
minPasswordLength = 8
|
||||
)
|
||||
|
||||
var usernameRegex = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`)
|
||||
|
||||
func ValidateUsername(username string) error {
|
||||
length := utf8.RuneCountInString(username)
|
||||
if length < minUsernameLength || length > maxUsernameLength {
|
||||
return errors.New("username must be between 1 and 32 characters")
|
||||
}
|
||||
if !usernameRegex.MatchString(username) {
|
||||
return errors.New("username can only contain [a-zA-Z0-9_.-]")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateAndNormalizePassword(password string) (string, error) {
|
||||
length := utf8.RuneCountInString(password)
|
||||
if length < minPasswordLength {
|
||||
return "", errors.New("password must be at least 8 characters long")
|
||||
}
|
||||
if length > maxPasswordLength {
|
||||
var truncated []rune
|
||||
for i, r := range password {
|
||||
if i >= maxPasswordLength {
|
||||
break
|
||||
}
|
||||
truncated = append(truncated, r)
|
||||
}
|
||||
password = string(truncated)
|
||||
}
|
||||
return password, nil
|
||||
}
|
||||
199
internal/db/psql/user_test.go
Normal file
199
internal/db/psql/user_test.go
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
package psql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func setupTestDataForUsers(t *testing.T) {
|
||||
truncateTestDataForUsers(t)
|
||||
// Insert additional test users
|
||||
err := store.Exec(context.Background(),
|
||||
`INSERT INTO users (username, password, role)
|
||||
VALUES ('test_user', $1, 'user'),
|
||||
('admin_user', $1, 'admin')`, []byte("hashed_password"))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func truncateTestDataForUsers(t *testing.T) {
|
||||
err := store.Exec(context.Background(),
|
||||
`DELETE FROM users WHERE id NOT IN (1)`,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = store.Exec(context.Background(),
|
||||
`ALTER SEQUENCE users_id_seq RESTART WITH 2`,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = store.Exec(context.Background(),
|
||||
`TRUNCATE api_keys RESTART IDENTITY CASCADE`,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestGetUserByUsername(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
setupTestDataForUsers(t)
|
||||
|
||||
// Test fetching an existing user
|
||||
user, err := store.GetUserByUsername(ctx, "test_user")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
assert.Equal(t, "test_user", user.Username)
|
||||
assert.Equal(t, "user", string(user.Role))
|
||||
|
||||
// Test fetching a non-existent user
|
||||
user, err = store.GetUserByUsername(ctx, "nonexistent_user")
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, user)
|
||||
}
|
||||
|
||||
func TestGetUserByApiKey(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
setupTestDataForUsers(t)
|
||||
|
||||
// Insert an API key for the test user
|
||||
err := store.Exec(ctx, `INSERT INTO api_keys (key, label, user_id) VALUES ('test_key', 'Test Key', 2)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test fetching a user by API key
|
||||
user, err := store.GetUserByApiKey(ctx, "test_key")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
assert.Equal(t, int32(2), user.ID)
|
||||
assert.Equal(t, "test_user", user.Username)
|
||||
|
||||
// Test fetching a user with a non-existent API key
|
||||
user, err = store.GetUserByApiKey(ctx, "nonexistent_key")
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, user)
|
||||
}
|
||||
|
||||
func TestSaveUser(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
setupTestDataForUsers(t)
|
||||
|
||||
// Save a new user
|
||||
opts := db.SaveUserOpts{
|
||||
Username: "new_user",
|
||||
Password: "secure_password",
|
||||
Role: "user",
|
||||
}
|
||||
user, err := store.SaveUser(ctx, opts)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
assert.Equal(t, "new_user", user.Username)
|
||||
assert.Equal(t, "user", string(user.Role))
|
||||
|
||||
// Verify the password was hashed
|
||||
var hashedPassword []byte
|
||||
err = store.QueryRow(ctx, `SELECT password FROM users WHERE username = $1`, "new_user").Scan(&hashedPassword)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, bcrypt.CompareHashAndPassword(hashedPassword, []byte(opts.Password)))
|
||||
|
||||
// Test validation failures
|
||||
_, err = store.SaveUser(ctx, db.SaveUserOpts{
|
||||
Username: "Q!@JH(F_H@#!*HF#*)&@",
|
||||
Password: "testpassword12345",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
_, err = store.SaveUser(ctx, db.SaveUserOpts{
|
||||
Username: "test_user",
|
||||
Password: "<3",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSaveApiKey(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
setupTestDataForUsers(t)
|
||||
|
||||
// Save an API key for the test user
|
||||
label := "New API Key"
|
||||
opts := db.SaveApiKeyOpts{
|
||||
Key: "new_api_key",
|
||||
Label: label,
|
||||
UserID: 2,
|
||||
}
|
||||
_, err := store.SaveApiKey(ctx, opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the API key was saved
|
||||
count, err := store.Count(ctx, `SELECT COUNT(*) FROM api_keys WHERE key = $1 AND user_id = $2`, opts.Key, opts.UserID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestGetApiKeysByUserID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
setupTestDataForUsers(t)
|
||||
|
||||
// Insert API keys for the test user
|
||||
err := store.Exec(ctx, `INSERT INTO api_keys (key, label, user_id) VALUES
|
||||
('key1', 'Key 1', 2),
|
||||
('key2', 'Key 2', 2)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fetch API keys for the test user
|
||||
keys, err := store.GetApiKeysByUserID(ctx, 2)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 2)
|
||||
assert.Equal(t, "key1", keys[0].Key)
|
||||
assert.Equal(t, "key2", keys[1].Key)
|
||||
}
|
||||
|
||||
func TestUpdateApiKeyLabel(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
setupTestDataForUsers(t)
|
||||
|
||||
// Insert an API key for the test user
|
||||
err := store.Exec(ctx, `INSERT INTO api_keys (key, label, user_id) VALUES ('key_to_update', 'Old Label', 2)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update the API key label
|
||||
opts := db.UpdateApiKeyLabelOpts{
|
||||
ID: 1,
|
||||
Label: "Updated Label",
|
||||
UserID: 2,
|
||||
}
|
||||
err = store.UpdateApiKeyLabel(ctx, opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the label was updated
|
||||
var label string
|
||||
err = store.QueryRow(ctx, `SELECT label FROM api_keys WHERE id = $1`, opts.ID).Scan(&label)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated Label", label)
|
||||
}
|
||||
|
||||
func TestDeleteApiKey(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
setupTestDataForUsers(t)
|
||||
|
||||
// Insert an API key for the test user
|
||||
err := store.Exec(ctx, `INSERT INTO api_keys (key, label, user_id) VALUES ('key_to_delete', 'Label', 2)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete the API key
|
||||
err = store.DeleteApiKey(ctx, 1) // Assuming the ID is auto-generated and starts from 1
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the API key was deleted
|
||||
count, err := store.Count(ctx, `SELECT COUNT(*) FROM api_keys WHERE id = $1`, 1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, count)
|
||||
}
|
||||
|
||||
func TestCountUsers(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
setupTestDataForUsers(t)
|
||||
|
||||
// Count the number of users
|
||||
count, err := store.Count(ctx, `SELECT COUNT(*) FROM users`)
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, count, 3) // Special user + test users
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue