chore: initial public commit

This commit is contained in:
Gabe Farrell 2025-06-11 19:45:39 -04:00
commit fc9054b78c
250 changed files with 32809 additions and 0 deletions

312
internal/db/psql/album.go Normal file
View file

@ -0,0 +1,312 @@
package psql
import (
"context"
"errors"
"strings"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/gabehf/koito/internal/utils"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
)
func (d *Psql) GetAlbum(ctx context.Context, opts db.GetAlbumOpts) (*models.Album, error) {
l := logger.FromContext(ctx)
var row repository.ReleasesWithTitle
var err error
if opts.ID != 0 {
l.Debug().Msgf("Fetching album from DB with id %d", opts.ID)
row, err = d.q.GetRelease(ctx, opts.ID)
} else if opts.MusicBrainzID != uuid.Nil {
l.Debug().Msgf("Fetching album from DB with MusicBrainz Release ID %s", opts.MusicBrainzID)
row, err = d.q.GetReleaseByMbzID(ctx, &opts.MusicBrainzID)
} else if opts.ArtistID != 0 && opts.Title != "" {
l.Debug().Msgf("Fetching album from DB with artist_id %d and title %s", opts.ArtistID, opts.Title)
row, err = d.q.GetReleaseByArtistAndTitle(ctx, repository.GetReleaseByArtistAndTitleParams{
ArtistID: opts.ArtistID,
Title: opts.Title,
})
} else if opts.ArtistID != 0 && len(opts.Titles) > 0 {
l.Debug().Msgf("Fetching release group from DB with artist_id %d and titles %v", opts.ArtistID, opts.Titles)
row, err = d.q.GetReleaseByArtistAndTitles(ctx, repository.GetReleaseByArtistAndTitlesParams{
ArtistID: opts.ArtistID,
Column1: opts.Titles,
})
} else {
return nil, errors.New("insufficient information to get album")
}
if err != nil {
return nil, err
}
count, err := d.q.CountListensFromRelease(ctx, repository.CountListensFromReleaseParams{
ListenedAt: time.Unix(0, 0),
ListenedAt_2: time.Now(),
ReleaseID: row.ID,
})
if err != nil {
return nil, err
}
return &models.Album{
ID: row.ID,
MbzID: row.MusicBrainzID,
Title: row.Title,
Image: row.Image,
VariousArtists: row.VariousArtists,
ListenCount: count,
}, nil
}
func (d *Psql) SaveAlbum(ctx context.Context, opts db.SaveAlbumOpts) (*models.Album, error) {
l := logger.FromContext(ctx)
var insertMbzID *uuid.UUID
var insertImage *uuid.UUID
if opts.MusicBrainzID != uuid.Nil {
insertMbzID = &opts.MusicBrainzID
}
if opts.Image != uuid.Nil {
insertImage = &opts.Image
}
if len(opts.ArtistIDs) < 1 {
return nil, errors.New("required parameter 'ArtistIDs' missing")
}
for _, aid := range opts.ArtistIDs {
if aid == 0 {
return nil, errors.New("none of 'ArtistIDs' may be 0")
}
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return nil, err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
l.Debug().Msgf("Inserting release '%s' into DB", opts.Title)
r, err := qtx.InsertRelease(ctx, repository.InsertReleaseParams{
MusicBrainzID: insertMbzID,
VariousArtists: opts.VariousArtists,
Image: insertImage,
ImageSource: pgtype.Text{String: opts.ImageSrc, Valid: opts.ImageSrc != ""},
})
if err != nil {
return nil, err
}
for _, artistId := range opts.ArtistIDs {
l.Debug().Msgf("Associating release '%s' to artist with ID %d", opts.Title, artistId)
err = qtx.AssociateArtistToRelease(ctx, repository.AssociateArtistToReleaseParams{
ArtistID: artistId,
ReleaseID: r.ID,
})
if err != nil {
return nil, err
}
}
l.Debug().Msgf("Saving canonical alias %s for release %d", opts.Title, r.ID)
err = qtx.InsertReleaseAlias(ctx, repository.InsertReleaseAliasParams{
ReleaseID: r.ID,
Alias: opts.Title,
Source: "Canonical",
IsPrimary: true,
})
if err != nil {
l.Err(err).Msgf("Failed to save canonical alias for album %d", r.ID)
}
err = tx.Commit(ctx)
if err != nil {
return nil, err
}
return &models.Album{
ID: r.ID,
MbzID: r.MusicBrainzID,
Title: opts.Title,
Image: r.Image,
VariousArtists: r.VariousArtists,
}, nil
}
func (d *Psql) AddArtistsToAlbum(ctx context.Context, opts db.AddArtistsToAlbumOpts) error {
l := logger.FromContext(ctx)
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
for _, id := range opts.ArtistIDs {
err := qtx.AssociateArtistToRelease(ctx, repository.AssociateArtistToReleaseParams{
ReleaseID: opts.AlbumID,
ArtistID: id,
})
if err != nil {
l.Error().Err(err).Msgf("Failed to associate release %d with artist %d", opts.AlbumID, id)
}
}
return tx.Commit(ctx)
}
func (d *Psql) UpdateAlbum(ctx context.Context, opts db.UpdateAlbumOpts) error {
l := logger.FromContext(ctx)
if opts.ID == 0 {
return errors.New("missing album id")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
if opts.MusicBrainzID != uuid.Nil {
l.Debug().Msgf("Updating release with ID %d with MusicBrainz ID %s", opts.ID, opts.MusicBrainzID)
err := qtx.UpdateReleaseMbzID(ctx, repository.UpdateReleaseMbzIDParams{
ID: opts.ID,
MusicBrainzID: &opts.MusicBrainzID,
})
if err != nil {
return err
}
}
if opts.Image != uuid.Nil {
l.Debug().Msgf("Updating release with ID %d with image %s", opts.ID, opts.Image)
err := qtx.UpdateReleaseImage(ctx, repository.UpdateReleaseImageParams{
ID: opts.ID,
Image: &opts.Image,
ImageSource: pgtype.Text{String: opts.ImageSrc, Valid: opts.ImageSrc != ""},
})
if err != nil {
return err
}
}
return tx.Commit(ctx)
}
func (d *Psql) SaveAlbumAliases(ctx context.Context, id int32, aliases []string, source string) error {
l := logger.FromContext(ctx)
if id == 0 {
return errors.New("album id not specified")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
existing, err := qtx.GetAllReleaseAliases(ctx, id)
if err != nil {
return err
}
for _, v := range existing {
aliases = append(aliases, v.Alias)
}
utils.Unique(&aliases)
for _, alias := range aliases {
if strings.TrimSpace(alias) == "" {
return errors.New("aliases cannot be blank")
}
err = qtx.InsertReleaseAlias(ctx, repository.InsertReleaseAliasParams{
Alias: strings.TrimSpace(alias),
ReleaseID: id,
Source: source,
IsPrimary: false,
})
if err != nil {
return err
}
}
return tx.Commit(ctx)
}
func (d *Psql) DeleteAlbum(ctx context.Context, id int32) error {
return d.q.DeleteRelease(ctx, id)
}
func (d *Psql) DeleteAlbumAlias(ctx context.Context, id int32, alias string) error {
return d.q.DeleteReleaseAlias(ctx, repository.DeleteReleaseAliasParams{
ReleaseID: id,
Alias: alias,
})
}
func (d *Psql) GetAllAlbumAliases(ctx context.Context, id int32) ([]models.Alias, error) {
rows, err := d.q.GetAllReleaseAliases(ctx, id)
if err != nil {
return nil, err
}
aliases := make([]models.Alias, len(rows))
for i, row := range rows {
aliases[i] = models.Alias{
ID: id,
Alias: row.Alias,
Source: row.Source,
Primary: row.IsPrimary,
}
}
return aliases, nil
}
func (d *Psql) SetPrimaryAlbumAlias(ctx context.Context, id int32, alias string) error {
l := logger.FromContext(ctx)
if id == 0 {
return errors.New("artist id not specified")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
// get all aliases
aliases, err := qtx.GetAllReleaseAliases(ctx, id)
if err != nil {
return err
}
primary := ""
exists := false
for _, v := range aliases {
if v.Alias == alias {
exists = true
}
if v.IsPrimary {
primary = v.Alias
}
}
if primary == alias {
// no-op rename
return nil
}
if !exists {
return errors.New("alias does not exist")
}
err = qtx.SetReleaseAliasPrimaryStatus(ctx, repository.SetReleaseAliasPrimaryStatusParams{
ReleaseID: id,
Alias: alias,
IsPrimary: true,
})
if err != nil {
return err
}
err = qtx.SetReleaseAliasPrimaryStatus(ctx, repository.SetReleaseAliasPrimaryStatusParams{
ReleaseID: id,
Alias: primary,
IsPrimary: false,
})
if err != nil {
return err
}
return tx.Commit(ctx)
}

View file

@ -0,0 +1,319 @@
package psql_test
import (
"context"
"testing"
"github.com/gabehf/koito/internal/catalog"
"github.com/gabehf/koito/internal/db"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func truncateTestData(t *testing.T) {
err := store.Exec(context.Background(),
`TRUNCATE
artists,
artist_aliases,
tracks,
artist_tracks,
releases,
artist_releases,
release_aliases,
listens
RESTART IDENTITY CASCADE`)
require.NoError(t, err)
}
func testDataForRelease(t *testing.T) {
truncateTestData(t)
err := store.Exec(context.Background(),
`INSERT INTO artists (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000001')`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
VALUES (1, 'ATARASHII GAKKO!', 'MusicBrainz', true)`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO artists (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000002')`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
VALUES (2, 'Masayuki Suzuki', 'MusicBrainz', true)`)
require.NoError(t, err)
}
func TestGetAlbum(t *testing.T) {
testDataForRelease(t)
ctx := context.Background()
// Insert test data
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
Title: "Test Release Group",
ArtistIDs: []int32{1},
})
require.NoError(t, err)
// Test GetAlbum by ID
result, err := store.GetAlbum(ctx, db.GetAlbumOpts{ID: rg.ID})
require.NoError(t, err)
assert.Equal(t, rg.ID, result.ID)
assert.Equal(t, "Test Release Group", result.Title)
// Test GetAlbum with insufficient information
_, err = store.GetAlbum(ctx, db.GetAlbumOpts{})
assert.Error(t, err)
truncateTestData(t)
}
func TestSaveAlbum(t *testing.T) {
testDataForRelease(t)
ctx := context.Background()
// Save release group with artist IDs
artistIDs := []int32{1, 2}
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
Title: "New Release Group",
ArtistIDs: artistIDs,
})
require.NoError(t, err)
// Verify release group was saved
assert.Equal(t, "New Release Group", rg.Title)
// Verify release was created for release group
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM releases_with_title
WHERE title = $1 AND id = $2
)`, "New Release Group", rg.ID)
require.NoError(t, err)
assert.True(t, exists, "expected release to exist")
// Verify artist associations were created for release group
for _, aid := range artistIDs {
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM artist_releases
WHERE artist_id = $1 AND release_id = $2
)`, aid, rg.ID)
require.NoError(t, err)
assert.True(t, exists, "expected artist association to exist")
}
truncateTestData(t)
}
func TestUpdateAlbum(t *testing.T) {
testDataForRelease(t)
ctx := context.Background()
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
Title: "Old Title",
ArtistIDs: []int32{1},
})
require.NoError(t, err)
newMbzID := uuid.New()
imgid := uuid.New()
err = store.UpdateAlbum(ctx, db.UpdateAlbumOpts{
ID: rg.ID,
MusicBrainzID: newMbzID,
Image: imgid,
ImageSrc: catalog.ImageSourceUserUpload,
})
require.NoError(t, err)
result, err := store.GetAlbum(ctx, db.GetAlbumOpts{ID: rg.ID})
require.NoError(t, err)
assert.Equal(t, newMbzID, *result.MbzID)
assert.Equal(t, imgid, *result.Image)
truncateTestData(t)
}
func TestAddArtistsToAlbum(t *testing.T) {
testDataForRelease(t)
ctx := context.Background()
// Insert test album
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
Title: "Test Album",
ArtistIDs: []int32{1},
})
require.NoError(t, err)
// Add additional artists to the album
err = store.AddArtistsToAlbum(ctx, db.AddArtistsToAlbumOpts{
AlbumID: rg.ID,
ArtistIDs: []int32{2},
})
require.NoError(t, err)
// Verify artist associations were created
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM artist_releases
WHERE artist_id = $1 AND release_id = $2
)`, 2, rg.ID)
require.NoError(t, err)
assert.True(t, exists, "expected artist association to exist")
truncateTestData(t)
}
func TestSaveAlbumAliases(t *testing.T) {
testDataForRelease(t)
ctx := context.Background()
// Insert test album
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
Title: "Test Album",
ArtistIDs: []int32{1},
})
require.NoError(t, err)
// Save aliases for the album
aliases := []string{"Alias 1", "Alias 2"}
err = store.SaveAlbumAliases(ctx, rg.ID, aliases, "TestSource")
require.NoError(t, err)
// Verify aliases were saved
for _, alias := range aliases {
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM release_aliases
WHERE release_id = $1 AND alias = $2
)`, rg.ID, alias)
require.NoError(t, err)
assert.True(t, exists, "expected alias to exist")
}
err = store.SetPrimaryAlbumAlias(ctx, 1, "Alias 1")
require.NoError(t, err)
album, err := store.GetAlbum(ctx, db.GetAlbumOpts{ID: rg.ID})
require.NoError(t, err)
assert.Equal(t, "Alias 1", album.Title)
err = store.SetPrimaryAlbumAlias(ctx, 1, "Fake Alias")
require.Error(t, err)
store.SetPrimaryAlbumAlias(ctx, 1, "Album One")
truncateTestData(t)
}
func TestDeleteAlbum(t *testing.T) {
testDataForRelease(t)
ctx := context.Background()
testDataForTopItems(t)
// Delete the album
err := store.DeleteAlbum(ctx, 1)
require.NoError(t, err)
// Verify album was deleted
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM releases
WHERE id = $1
)`, 1)
require.NoError(t, err)
assert.False(t, exists, "expected album to be deleted")
// Verify album's track was deleted
exists, err = store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM tracks
WHERE id = $1
)`, 1)
require.NoError(t, err)
assert.False(t, exists, "expected album's tracks to be deleted")
// Verify album's listens was deleted
exists, err = store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM listens
WHERE track_id = $1
)`, 1)
require.NoError(t, err)
assert.False(t, exists, "expected album's listens to be deleted")
truncateTestData(t)
}
func TestDeleteAlbumAlias(t *testing.T) {
testDataForRelease(t)
ctx := context.Background()
// Insert test album
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
Title: "Test Album",
ArtistIDs: []int32{1},
})
require.NoError(t, err)
// Save aliases for the album
aliases := []string{"Alias 1", "Alias 2"}
err = store.SaveAlbumAliases(ctx, rg.ID, aliases, "TestSource")
require.NoError(t, err)
// Delete one alias
err = store.DeleteAlbumAlias(ctx, rg.ID, "Alias 1")
require.NoError(t, err)
// Verify alias was deleted
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM release_aliases
WHERE release_id = $1 AND alias = $2
)`, rg.ID, "Alias 1")
require.NoError(t, err)
assert.False(t, exists, "expected alias to be deleted")
// Verify other alias still exists
exists, err = store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM release_aliases
WHERE release_id = $1 AND alias = $2
)`, rg.ID, "Alias 2")
require.NoError(t, err)
assert.True(t, exists, "expected alias to still exist")
truncateTestData(t)
}
func TestGetAllAlbumAliases(t *testing.T) {
testDataForRelease(t)
ctx := context.Background()
// Insert test album
rg, err := store.SaveAlbum(ctx, db.SaveAlbumOpts{
Title: "Test Album",
ArtistIDs: []int32{1},
})
require.NoError(t, err)
// Save aliases for the album
aliases := []string{"Alias 1", "Alias 2"}
err = store.SaveAlbumAliases(ctx, rg.ID, aliases, "TestSource")
require.NoError(t, err)
// Retrieve all aliases
result, err := store.GetAllAlbumAliases(ctx, rg.ID)
require.NoError(t, err)
assert.Len(t, result, len(aliases)+1) // new + canonical
for _, alias := range aliases {
found := false
for _, res := range result {
if res.Alias == alias {
found = true
break
}
}
assert.True(t, found, "expected alias to be retrieved")
}
truncateTestData(t)
}

309
internal/db/psql/artist.go Normal file
View file

@ -0,0 +1,309 @@
package psql
import (
"context"
"errors"
"strings"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/gabehf/koito/internal/utils"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
)
func (d *Psql) GetArtist(ctx context.Context, opts db.GetArtistOpts) (*models.Artist, error) {
l := logger.FromContext(ctx)
if opts.ID != 0 {
l.Debug().Msgf("Fetching artist from DB with id %d", opts.ID)
row, err := d.q.GetArtist(ctx, opts.ID)
if err != nil {
return nil, err
}
count, err := d.q.CountListensFromArtist(ctx, repository.CountListensFromArtistParams{
ListenedAt: time.Unix(0, 0),
ListenedAt_2: time.Now(),
ArtistID: row.ID,
})
if err != nil {
return nil, err
}
return &models.Artist{
ID: row.ID,
MbzID: row.MusicBrainzID,
Name: row.Name,
Aliases: row.Aliases,
Image: row.Image,
ListenCount: count,
}, nil
} else if opts.MusicBrainzID != uuid.Nil {
l.Debug().Msgf("Fetching artist from DB with MusicBrainz ID %s", opts.MusicBrainzID)
row, err := d.q.GetArtistByMbzID(ctx, &opts.MusicBrainzID)
if err != nil {
return nil, err
}
count, err := d.q.CountListensFromArtist(ctx, repository.CountListensFromArtistParams{
ListenedAt: time.Unix(0, 0),
ListenedAt_2: time.Now(),
ArtistID: row.ID,
})
if err != nil {
return nil, err
}
return &models.Artist{
ID: row.ID,
MbzID: row.MusicBrainzID,
Name: row.Name,
Aliases: row.Aliases,
Image: row.Image,
ListenCount: count,
}, nil
} else if opts.Name != "" {
l.Debug().Msgf("Fetching artist from DB with name '%s'", opts.Name)
row, err := d.q.GetArtistByName(ctx, opts.Name)
if err != nil {
return nil, err
}
count, err := d.q.CountListensFromArtist(ctx, repository.CountListensFromArtistParams{
ListenedAt: time.Unix(0, 0),
ListenedAt_2: time.Now(),
ArtistID: row.ID,
})
if err != nil {
return nil, err
}
return &models.Artist{
ID: row.ID,
MbzID: row.MusicBrainzID,
Name: row.Name,
Aliases: row.Aliases,
Image: row.Image,
ListenCount: count,
}, nil
} else {
return nil, errors.New("insufficient information to get artist")
}
}
// Inserts all unique aliases into the DB with specified source
func (d *Psql) SaveArtistAliases(ctx context.Context, id int32, aliases []string, source string) error {
l := logger.FromContext(ctx)
if id == 0 {
return errors.New("artist id not specified")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
existing, err := qtx.GetAllArtistAliases(ctx, id)
if err != nil {
return err
}
for _, v := range existing {
aliases = append(aliases, v.Alias)
}
utils.Unique(&aliases)
for _, alias := range aliases {
if strings.TrimSpace(alias) == "" {
return errors.New("aliases cannot be blank")
}
err = qtx.InsertArtistAlias(ctx, repository.InsertArtistAliasParams{
Alias: strings.TrimSpace(alias),
ArtistID: id,
Source: source,
IsPrimary: false,
})
if err != nil {
return err
}
}
return tx.Commit(ctx)
}
func (d *Psql) DeleteArtist(ctx context.Context, id int32) error {
return d.q.DeleteArtist(ctx, id)
}
// Equivalent to Psql.SaveArtist, then Psql.SaveMbzAliases
func (d *Psql) SaveArtist(ctx context.Context, opts db.SaveArtistOpts) (*models.Artist, error) {
l := logger.FromContext(ctx)
var insertMbzID *uuid.UUID
var insertImage *uuid.UUID
if opts.MusicBrainzID != uuid.Nil {
insertMbzID = &opts.MusicBrainzID
}
if opts.Image != uuid.Nil {
insertImage = &opts.Image
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return nil, err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
opts.Name = strings.TrimSpace(opts.Name)
if opts.Name == "" {
return nil, errors.New("name must not be blank")
}
l.Debug().Msgf("Inserting artist '%s' into DB", opts.Name)
a, err := qtx.InsertArtist(ctx, repository.InsertArtistParams{
MusicBrainzID: insertMbzID,
Image: insertImage,
ImageSource: pgtype.Text{String: opts.ImageSrc, Valid: opts.ImageSrc != ""},
})
if err != nil {
return nil, err
}
l.Debug().Msgf("Inserting canonical alias '%s' into DB for artist with id %d", opts.Name, a.ID)
err = qtx.InsertArtistAlias(ctx, repository.InsertArtistAliasParams{
ArtistID: a.ID,
Alias: opts.Name,
Source: "Canonical",
IsPrimary: true,
})
if err != nil {
l.Error().Err(err).Msgf("Error inserting canonical alias for artist '%s'", opts.Name)
return nil, err
}
err = tx.Commit(ctx)
if err != nil {
l.Err(err).Msg("Failed to commit insert artist transaction")
return nil, err
}
artist := &models.Artist{
ID: a.ID,
Name: opts.Name,
Image: a.Image,
MbzID: a.MusicBrainzID,
Aliases: []string{opts.Name},
}
if len(opts.Aliases) > 0 {
l.Debug().Msgf("Inserting aliases '%v' into DB for artist '%s'", opts.Aliases, opts.Name)
err = d.SaveArtistAliases(ctx, a.ID, opts.Aliases, "MusicBrainz")
if err != nil {
return nil, err
}
artist.Aliases = opts.Aliases
}
return artist, nil
}
func (d *Psql) UpdateArtist(ctx context.Context, opts db.UpdateArtistOpts) error {
l := logger.FromContext(ctx)
if opts.ID == 0 {
return errors.New("artist id not specified")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
if opts.MusicBrainzID != uuid.Nil {
l.Debug().Msgf("Updating artist with id %d with MusicBrainz ID %s", opts.ID, opts.MusicBrainzID)
err := qtx.UpdateArtistMbzID(ctx, repository.UpdateArtistMbzIDParams{
ID: opts.ID,
MusicBrainzID: &opts.MusicBrainzID,
})
if err != nil {
return err
}
}
if opts.Image != uuid.Nil {
l.Debug().Msgf("Updating artist with id %d with image %s", opts.ID, opts.Image)
err = qtx.UpdateArtistImage(ctx, repository.UpdateArtistImageParams{
ID: opts.ID,
Image: &opts.Image,
ImageSource: pgtype.Text{String: opts.ImageSrc, Valid: opts.ImageSrc != ""},
})
if err != nil {
return err
}
}
return tx.Commit(ctx)
}
func (d *Psql) DeleteArtistAlias(ctx context.Context, id int32, alias string) error {
return d.q.DeleteArtistAlias(ctx, repository.DeleteArtistAliasParams{
ArtistID: id,
Alias: alias,
})
}
func (d *Psql) GetAllArtistAliases(ctx context.Context, id int32) ([]models.Alias, error) {
rows, err := d.q.GetAllArtistAliases(ctx, id)
if err != nil {
return nil, err
}
aliases := make([]models.Alias, len(rows))
for i, row := range rows {
aliases[i] = models.Alias{
ID: id,
Alias: row.Alias,
Source: row.Source,
Primary: row.IsPrimary,
}
}
return aliases, nil
}
func (d *Psql) SetPrimaryArtistAlias(ctx context.Context, id int32, alias string) error {
l := logger.FromContext(ctx)
if id == 0 {
return errors.New("artist id not specified")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
// get all aliases
aliases, err := qtx.GetAllArtistAliases(ctx, id)
if err != nil {
return err
}
primary := ""
exists := false
for _, v := range aliases {
if v.Alias == alias {
exists = true
}
if v.IsPrimary {
primary = v.Alias
}
}
if primary == alias {
// no-op rename
return nil
}
if !exists {
return errors.New("alias does not exist")
}
err = qtx.SetArtistAliasPrimaryStatus(ctx, repository.SetArtistAliasPrimaryStatusParams{
ArtistID: id,
Alias: alias,
IsPrimary: true,
})
if err != nil {
return err
}
err = qtx.SetArtistAliasPrimaryStatus(ctx, repository.SetArtistAliasPrimaryStatusParams{
ArtistID: id,
Alias: primary,
IsPrimary: false,
})
if err != nil {
return err
}
return tx.Commit(ctx)
}

View file

@ -0,0 +1,247 @@
package psql_test
import (
"context"
"slices"
"testing"
"github.com/gabehf/koito/internal/catalog"
"github.com/gabehf/koito/internal/db"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetArtist(t *testing.T) {
ctx := context.Background()
mbzId := uuid.MustParse("00000000-0000-0000-0000-000000000001")
// Insert test data
artist, err := store.SaveArtist(ctx, db.SaveArtistOpts{
Name: "Test Artist",
MusicBrainzID: mbzId,
})
require.NoError(t, err)
// Test GetArtist by ID
result, err := store.GetArtist(ctx, db.GetArtistOpts{ID: artist.ID})
require.NoError(t, err)
assert.Equal(t, artist.ID, result.ID)
assert.Equal(t, "Test Artist", result.Name)
// Test GetArtist by Name
result, err = store.GetArtist(ctx, db.GetArtistOpts{Name: artist.Name})
require.NoError(t, err)
assert.Equal(t, artist.ID, result.ID)
// Test GetArtist by MusicBrainzID
result, err = store.GetArtist(ctx, db.GetArtistOpts{MusicBrainzID: mbzId})
require.NoError(t, err)
assert.Equal(t, artist.ID, result.ID)
// Test GetArtist with insufficient information
_, err = store.GetArtist(ctx, db.GetArtistOpts{})
assert.Error(t, err)
truncateTestData(t)
}
func TestSaveAliases(t *testing.T) {
ctx := context.Background()
// Insert test artist
artist, err := store.SaveArtist(ctx, db.SaveArtistOpts{
Name: "Alias Artist",
})
require.NoError(t, err)
// Save aliases
aliases := []string{"Alias1", "Alias2"}
err = store.SaveArtistAliases(ctx, artist.ID, aliases, "MusicBrainz")
require.NoError(t, err)
// Verify aliases were saved
for _, alias := range aliases {
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM artist_aliases
WHERE artist_id = $1 AND alias = $2
)`, artist.ID, alias)
require.NoError(t, err)
assert.True(t, exists, "expected alias to exist")
}
err = store.SetPrimaryArtistAlias(ctx, 1, "Alias1")
require.NoError(t, err)
artist, err = store.GetArtist(ctx, db.GetArtistOpts{ID: artist.ID})
require.NoError(t, err)
assert.Equal(t, "Alias1", artist.Name)
err = store.SetPrimaryArtistAlias(ctx, 1, "Fake Alias")
require.Error(t, err)
truncateTestData(t)
}
func TestSaveArtist(t *testing.T) {
ctx := context.Background()
// Save artist with aliases
aliases := []string{"Alias1", "Alias2"}
artist, err := store.SaveArtist(ctx, db.SaveArtistOpts{
Name: "New Artist",
Aliases: aliases,
})
require.NoError(t, err)
// Verify artist was saved
assert.Equal(t, "New Artist", artist.Name)
// Verify aliases were saved
for _, alias := range slices.Concat(aliases, []string{"New Artist"}) {
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM artist_aliases
WHERE artist_id = $1 AND alias = $2
)`, artist.ID, alias)
require.NoError(t, err)
assert.True(t, exists, "expected alias '%s' to exist", alias)
}
truncateTestData(t)
}
func TestUpdateArtist(t *testing.T) {
ctx := context.Background()
// Insert test artist
artist, err := store.SaveArtist(ctx, db.SaveArtistOpts{
Name: "Old Name",
})
require.NoError(t, err)
imgid := uuid.New()
err = store.UpdateArtist(ctx, db.UpdateArtistOpts{
ID: artist.ID,
Image: imgid,
ImageSrc: catalog.ImageSourceUserUpload,
})
require.NoError(t, err)
result, err := store.GetArtist(ctx, db.GetArtistOpts{ID: artist.ID})
require.NoError(t, err)
assert.Equal(t, imgid, *result.Image)
truncateTestData(t)
}
func TestGetAllArtistAliases(t *testing.T) {
ctx := context.Background()
// Insert test artist
artist, err := store.SaveArtist(ctx, db.SaveArtistOpts{
Name: "Alias Artist",
Aliases: []string{"Alias1", "Alias2"},
})
require.NoError(t, err)
// Retrieve all aliases
result, err := store.GetAllArtistAliases(ctx, artist.ID)
require.NoError(t, err)
assert.Len(t, result, 3) // Includes canonical alias
// Verify aliases were retrieved
expectedAliases := []string{"Alias Artist", "Alias1", "Alias2"}
for _, alias := range expectedAliases {
found := false
for _, res := range result {
if res.Alias == alias {
found = true
break
}
}
assert.True(t, found, "expected alias '%s' to be retrieved", alias)
}
truncateTestData(t)
}
func TestDeleteArtistAlias(t *testing.T) {
ctx := context.Background()
// Insert test artist
artist, err := store.SaveArtist(ctx, db.SaveArtistOpts{
Name: "Alias Artist",
Aliases: []string{"Alias1", "Alias2"},
})
require.NoError(t, err)
// Delete one alias
err = store.DeleteArtistAlias(ctx, artist.ID, "Alias1")
require.NoError(t, err)
// Verify alias was deleted
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM artist_aliases
WHERE artist_id = $1 AND alias = $2
)`, artist.ID, "Alias1")
require.NoError(t, err)
assert.False(t, exists, "expected alias to be deleted")
// Verify other alias still exists
exists, err = store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM artist_aliases
WHERE artist_id = $1 AND alias = $2
)`, artist.ID, "Alias2")
require.NoError(t, err)
assert.True(t, exists, "expected alias to still exist")
truncateTestData(t)
}
func TestDeleteArtist(t *testing.T) {
ctx := context.Background()
// set up a lot of test data, 4 artists, 4 albums, 4 tracks, 10 listens
testDataForTopItems(t)
// Delete the artist
err := store.DeleteArtist(ctx, 1)
require.NoError(t, err)
// Verify artist was deleted
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM artists
WHERE id = $1
)`, 1)
require.NoError(t, err)
assert.False(t, exists, "expected artist to be deleted")
// Verify artist's release was deleted
exists, err = store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM releases
WHERE id = $1
)`, 1)
require.NoError(t, err)
assert.False(t, exists, "expected artist's release to be deleted")
// Verify artist's track was deleted
exists, err = store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM tracks
WHERE id = $1
)`, 1)
require.NoError(t, err)
assert.False(t, exists, "expected artist's tracks to be deleted")
// Verify artist's listens was deleted
exists, err = store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM listens
WHERE track_id = $1
)`, 1)
require.NoError(t, err)
assert.False(t, exists, "expected artist's listens to be deleted")
truncateTestData(t)
}

View file

@ -0,0 +1,70 @@
package psql
import (
"context"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/repository"
)
func (p *Psql) CountListens(ctx context.Context, period db.Period) (int64, error) {
t2 := time.Now()
t1 := db.StartTimeFromPeriod(period)
count, err := p.q.CountListens(ctx, repository.CountListensParams{
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return 0, err
}
return count, nil
}
func (p *Psql) CountTracks(ctx context.Context, period db.Period) (int64, error) {
t2 := time.Now()
t1 := db.StartTimeFromPeriod(period)
count, err := p.q.CountTopTracks(ctx, repository.CountTopTracksParams{
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return 0, err
}
return count, nil
}
func (p *Psql) CountAlbums(ctx context.Context, period db.Period) (int64, error) {
t2 := time.Now()
t1 := db.StartTimeFromPeriod(period)
count, err := p.q.CountTopReleases(ctx, repository.CountTopReleasesParams{
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return 0, err
}
return count, nil
}
func (p *Psql) CountArtists(ctx context.Context, period db.Period) (int64, error) {
t2 := time.Now()
t1 := db.StartTimeFromPeriod(period)
count, err := p.q.CountTopArtists(ctx, repository.CountTopArtistsParams{
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return 0, err
}
return count, nil
}
func (p *Psql) CountTimeListened(ctx context.Context, period db.Period) (int64, error) {
t2 := time.Now()
t1 := db.StartTimeFromPeriod(period)
count, err := p.q.CountTimeListened(ctx, repository.CountTimeListenedParams{
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return 0, err
}
return count, nil
}

View file

@ -0,0 +1,76 @@
package psql_test
import (
"context"
"testing"
"github.com/gabehf/koito/internal/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCountListens(t *testing.T) {
ctx := context.Background()
testDataForTopItems(t)
// Test CountListens
period := db.PeriodWeek
count, err := store.CountListens(ctx, period)
require.NoError(t, err)
assert.Equal(t, int64(1), count, "expected listens count to match inserted data")
truncateTestData(t)
}
func TestCountTracks(t *testing.T) {
ctx := context.Background()
testDataForTopItems(t)
// Test CountTracks
period := db.PeriodMonth
count, err := store.CountTracks(ctx, period)
require.NoError(t, err)
assert.Equal(t, int64(2), count, "expected tracks count to match inserted data")
truncateTestData(t)
}
func TestCountAlbums(t *testing.T) {
ctx := context.Background()
testDataForTopItems(t)
// Test CountAlbums
period := db.PeriodYear
count, err := store.CountAlbums(ctx, period)
require.NoError(t, err)
assert.Equal(t, int64(3), count, "expected albums count to match inserted data")
truncateTestData(t)
}
func TestCountArtists(t *testing.T) {
ctx := context.Background()
testDataForTopItems(t)
// Test CountArtists
period := db.PeriodAllTime
count, err := store.CountArtists(ctx, period)
require.NoError(t, err)
assert.Equal(t, int64(4), count, "expected artists count to match inserted data")
truncateTestData(t)
}
func TestCountTimeListened(t *testing.T) {
ctx := context.Background()
testDataForTopItems(t)
// Test CountTimeListened
period := db.PeriodMonth
count, err := store.CountTimeListened(ctx, period)
require.NoError(t, err)
// 3 listens in past month, each 100 seconds
assert.Equal(t, int64(300), count, "expected total time listened to match inserted data")
truncateTestData(t)
}

View file

@ -0,0 +1,74 @@
package psql
import (
"context"
"encoding/json"
"errors"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
)
func (d *Psql) ImageHasAssociation(ctx context.Context, image uuid.UUID) (bool, error) {
_, err := d.q.GetReleaseByImageID(ctx, &image)
if err == nil {
return true, err
} else if !errors.Is(err, pgx.ErrNoRows) {
return false, err
}
_, err = d.q.GetArtistByImage(ctx, &image)
if err == nil {
return true, err
} else if !errors.Is(err, pgx.ErrNoRows) {
return false, err
}
return false, nil
}
func (d *Psql) GetImageSource(ctx context.Context, image uuid.UUID) (string, error) {
r, err := d.q.GetReleaseByImageID(ctx, &image)
if err == nil {
return r.ImageSource.String, err
} else if !errors.Is(err, pgx.ErrNoRows) {
return "", err
}
rr, err := d.q.GetArtistByImage(ctx, &image)
if err == nil {
return rr.ImageSource.String, err
} else if !errors.Is(err, pgx.ErrNoRows) {
return "", err
}
return "", nil
}
func (d *Psql) AlbumsWithoutImages(ctx context.Context, from int32) ([]*models.Album, error) {
l := logger.FromContext(ctx)
rows, err := d.q.GetReleasesWithoutImages(ctx, repository.GetReleasesWithoutImagesParams{
Limit: 20,
ID: from,
})
if err != nil {
return nil, err
}
albums := make([]*models.Album, len(rows))
for i, row := range rows {
artists := make([]models.SimpleArtist, 0)
err = json.Unmarshal(row.Artists, &artists)
if err != nil {
l.Err(err).Msgf("Error unmarshalling artists for release group with id %d", row.ID)
artists = nil
}
albums[i] = &models.Album{
ID: row.ID,
Image: row.Image,
Title: row.Title,
MbzID: row.MusicBrainzID,
VariousArtists: row.VariousArtists,
Artists: artists,
}
}
return albums, nil
}

View file

@ -0,0 +1,106 @@
package psql_test
import (
"context"
"testing"
"github.com/gabehf/koito/internal/catalog"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setupTestDataForImages(t *testing.T) {
truncateTestData(t)
// Insert artists
err := store.Exec(context.Background(),
`INSERT INTO artists (musicbrainz_id, image, image_source)
VALUES ('00000000-0000-0000-0000-000000000001', '11111111-1111-1111-1111-111111111111', 'User Upload'),
('00000000-0000-0000-0000-000000000002', NULL, NULL)`)
require.NoError(t, err)
// Insert artist aliases
err = store.Exec(context.Background(),
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
VALUES (1, 'Artist One', 'Testing', true),
(2, 'Artist Two', 'Testing', true)`)
require.NoError(t, err)
// Insert albums
err = store.Exec(context.Background(),
`INSERT INTO releases (musicbrainz_id, image, image_source)
VALUES ('22222222-2222-2222-2222-222222222222', '33333333-3333-3333-3333-333333333333', 'Automatic'),
('44444444-4444-4444-4444-444444444444', NULL, NULL)`)
require.NoError(t, err)
// Insert release aliases
err = store.Exec(context.Background(),
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
VALUES (1, 'Album One', 'Testing', true),
(2, 'Album Two', 'Testing', true)`)
require.NoError(t, err)
// Associate albums with artists
err = store.Exec(context.Background(),
`INSERT INTO artist_releases (artist_id, release_id)
VALUES (1, 1), (2, 2)`)
require.NoError(t, err)
}
func TestImageHasAssociation(t *testing.T) {
ctx := context.Background()
setupTestDataForImages(t)
// Test image with association
imageID := uuid.MustParse("11111111-1111-1111-1111-111111111111")
hasAssociation, err := store.ImageHasAssociation(ctx, imageID)
require.NoError(t, err)
assert.True(t, hasAssociation, "expected image to have an association")
// Test image without association
imageID = uuid.MustParse("55555555-5555-5555-5555-555555555555")
hasAssociation, err = store.ImageHasAssociation(ctx, imageID)
require.NoError(t, err)
assert.False(t, hasAssociation, "expected image to have no association")
truncateTestData(t)
}
func TestGetImageSource(t *testing.T) {
ctx := context.Background()
setupTestDataForImages(t)
// Test image source for an album
imageID := uuid.MustParse("33333333-3333-3333-3333-333333333333")
source, err := store.GetImageSource(ctx, imageID)
require.NoError(t, err)
assert.Equal(t, "Automatic", source, "expected image source to match")
// Test image source for an artist
imageID = uuid.MustParse("11111111-1111-1111-1111-111111111111")
source, err = store.GetImageSource(ctx, imageID)
require.NoError(t, err)
assert.Equal(t, catalog.ImageSourceUserUpload, source, "expected image source to match")
// Test image source for a non-existent image
imageID = uuid.MustParse("55555555-5555-5555-5555-555555555555")
source, err = store.GetImageSource(ctx, imageID)
require.NoError(t, err)
assert.Equal(t, "", source, "expected no image source for non-existent image")
truncateTestData(t)
}
func TestAlbumsWithoutImages(t *testing.T) {
ctx := context.Background()
setupTestDataForImages(t)
// Test albums without images
albums, err := store.AlbumsWithoutImages(ctx, 0)
require.NoError(t, err)
require.Len(t, albums, 1, "expected one album without an image")
assert.Equal(t, "Album Two", albums[0].Title, "expected album title to match")
truncateTestData(t)
}

218
internal/db/psql/listen.go Normal file
View file

@ -0,0 +1,218 @@
package psql
import (
"context"
"encoding/json"
"errors"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/gabehf/koito/internal/utils"
)
func (d *Psql) GetListensPaginated(ctx context.Context, opts db.GetItemsOpts) (*db.PaginatedResponse[*models.Listen], error) {
l := logger.FromContext(ctx)
offset := (opts.Page - 1) * opts.Limit
t1, t2, err := utils.DateRange(opts.Week, opts.Month, opts.Year)
if err != nil {
return nil, err
}
if opts.Month == 0 && opts.Year == 0 {
// use period, not date range
t2 = time.Now()
t1 = db.StartTimeFromPeriod(opts.Period)
}
if opts.Limit == 0 {
opts.Limit = DefaultItemsPerPage
}
var listens []*models.Listen
var count int64
if opts.TrackID > 0 {
l.Debug().Msgf("Fetching %d listens with period %s on page %d from range %v to %v",
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetLastListensFromTrackPaginated(ctx, repository.GetLastListensFromTrackPaginatedParams{
ListenedAt: t1,
ListenedAt_2: t2,
Limit: int32(opts.Limit),
Offset: int32(offset),
ID: int32(opts.TrackID),
})
if err != nil {
return nil, err
}
listens = make([]*models.Listen, len(rows))
for i, row := range rows {
t := &models.Listen{
Track: models.Track{
Title: row.TrackTitle,
ID: row.TrackID,
},
Time: row.ListenedAt,
}
err = json.Unmarshal(row.Artists, &t.Track.Artists)
if err != nil {
return nil, err
}
listens[i] = t
}
count, err = d.q.CountListensFromTrack(ctx, repository.CountListensFromTrackParams{
ListenedAt: t1,
ListenedAt_2: t2,
TrackID: int32(opts.TrackID),
})
if err != nil {
return nil, err
}
} else if opts.AlbumID > 0 {
l.Debug().Msgf("Fetching %d listens with period %s on page %d from range %v to %v",
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetLastListensFromReleasePaginated(ctx, repository.GetLastListensFromReleasePaginatedParams{
ListenedAt: t1,
ListenedAt_2: t2,
Limit: int32(opts.Limit),
Offset: int32(offset),
ReleaseID: int32(opts.AlbumID),
})
if err != nil {
return nil, err
}
listens = make([]*models.Listen, len(rows))
for i, row := range rows {
t := &models.Listen{
Track: models.Track{
Title: row.TrackTitle,
ID: row.TrackID,
},
Time: row.ListenedAt,
}
err = json.Unmarshal(row.Artists, &t.Track.Artists)
if err != nil {
return nil, err
}
listens[i] = t
}
count, err = d.q.CountListensFromRelease(ctx, repository.CountListensFromReleaseParams{
ListenedAt: t1,
ListenedAt_2: t2,
ReleaseID: int32(opts.AlbumID),
})
if err != nil {
return nil, err
}
} else if opts.ArtistID > 0 {
l.Debug().Msgf("Fetching %d listens with period %s on page %d from range %v to %v",
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetLastListensFromArtistPaginated(ctx, repository.GetLastListensFromArtistPaginatedParams{
ListenedAt: t1,
ListenedAt_2: t2,
Limit: int32(opts.Limit),
Offset: int32(offset),
ArtistID: int32(opts.ArtistID),
})
if err != nil {
return nil, err
}
listens = make([]*models.Listen, len(rows))
for i, row := range rows {
t := &models.Listen{
Track: models.Track{
Title: row.TrackTitle,
ID: row.TrackID,
},
Time: row.ListenedAt,
}
err = json.Unmarshal(row.Artists, &t.Track.Artists)
if err != nil {
return nil, err
}
listens[i] = t
}
count, err = d.q.CountListensFromArtist(ctx, repository.CountListensFromArtistParams{
ListenedAt: t1,
ListenedAt_2: t2,
ArtistID: int32(opts.ArtistID),
})
if err != nil {
return nil, err
}
} else {
l.Debug().Msgf("Fetching %d listens with period %s on page %d from range %v to %v",
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetLastListensPaginated(ctx, repository.GetLastListensPaginatedParams{
ListenedAt: t1,
ListenedAt_2: t2,
Limit: int32(opts.Limit),
Offset: int32(offset),
})
if err != nil {
return nil, err
}
listens = make([]*models.Listen, len(rows))
for i, row := range rows {
t := &models.Listen{
Track: models.Track{
Title: row.TrackTitle,
ID: row.TrackID,
},
Time: row.ListenedAt,
}
err = json.Unmarshal(row.Artists, &t.Track.Artists)
if err != nil {
return nil, err
}
listens[i] = t
}
count, err = d.q.CountListens(ctx, repository.CountListensParams{
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return nil, err
}
l.Debug().Msgf("Database responded with %d tracks out of a total %d", len(rows), count)
}
return &db.PaginatedResponse[*models.Listen]{
Items: listens,
TotalCount: count,
ItemsPerPage: int32(opts.Limit),
HasNextPage: int64(offset+len(listens)) < count,
CurrentPage: int32(opts.Page),
}, nil
}
func (d *Psql) SaveListen(ctx context.Context, opts db.SaveListenOpts) error {
l := logger.FromContext(ctx)
if opts.TrackID == 0 {
return errors.New("required parameter TrackID missing")
}
if opts.Time.IsZero() {
opts.Time = time.Now()
}
var client *string
if opts.Client != "" {
client = &opts.Client
}
l.Debug().Msgf("Inserting listen for track with id %d at time %v into DB", opts.TrackID, opts.Time)
return d.q.InsertListen(ctx, repository.InsertListenParams{
TrackID: opts.TrackID,
ListenedAt: opts.Time,
UserID: opts.UserID,
Client: client,
})
}
func (d *Psql) DeleteListen(ctx context.Context, trackId int32, listenedAt time.Time) error {
l := logger.FromContext(ctx)
if trackId == 0 {
return errors.New("required parameter 'trackId' missing")
}
l.Debug().Msgf("Deleting listen from track %d at time %s from DB", trackId, listenedAt)
return d.q.DeleteListen(ctx, repository.DeleteListenParams{
TrackID: trackId,
ListenedAt: listenedAt,
})
}

View file

@ -0,0 +1,109 @@
package psql
import (
"context"
"errors"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/repository"
)
func (d *Psql) GetListenActivity(ctx context.Context, opts db.ListenActivityOpts) ([]db.ListenActivityItem, error) {
l := logger.FromContext(ctx)
if opts.Month != 0 && opts.Year == 0 {
return nil, errors.New("year must be specified with month")
}
// Default to range = 12 if not set
if opts.Range == 0 {
opts.Range = db.DefaultRange
}
t1, t2 := db.ListenActivityOptsToTimes(opts)
var listenActivity []db.ListenActivityItem
if opts.AlbumID > 0 {
l.Debug().Msgf("Fetching listen activity for %d %s(s) from %v to %v for release group %d",
opts.Range, opts.Step, t1.Format("Jan 02, 2006 15:04:05"), t2.Format("Jan 02, 2006 15:04:05"), opts.AlbumID)
rows, err := d.q.ListenActivityForRelease(ctx, repository.ListenActivityForReleaseParams{
Column1: t1,
Column2: t2,
Column3: stepToInterval(opts.Step),
ReleaseID: opts.AlbumID,
})
if err != nil {
return nil, err
}
listenActivity = make([]db.ListenActivityItem, len(rows))
for i, row := range rows {
t := db.ListenActivityItem{
Start: row.BucketStart,
Listens: row.ListenCount,
}
listenActivity[i] = t
}
l.Debug().Msgf("Database responded with %d steps", len(rows))
} else if opts.ArtistID > 0 {
l.Debug().Msgf("Fetching listen activity for %d %s(s) from %v to %v for artist %d",
opts.Range, opts.Step, t1.Format("Jan 02, 2006 15:04:05"), t2.Format("Jan 02, 2006 15:04:05"), opts.ArtistID)
rows, err := d.q.ListenActivityForArtist(ctx, repository.ListenActivityForArtistParams{
Column1: t1,
Column2: t2,
Column3: stepToInterval(opts.Step),
ArtistID: opts.ArtistID,
})
if err != nil {
return nil, err
}
listenActivity = make([]db.ListenActivityItem, len(rows))
for i, row := range rows {
t := db.ListenActivityItem{
Start: row.BucketStart,
Listens: row.ListenCount,
}
listenActivity[i] = t
}
l.Debug().Msgf("Database responded with %d steps", len(rows))
} else if opts.TrackID > 0 {
l.Debug().Msgf("Fetching listen activity for %d %s(s) from %v to %v for track %d",
opts.Range, opts.Step, t1.Format("Jan 02, 2006 15:04:05"), t2.Format("Jan 02, 2006 15:04:05"), opts.TrackID)
rows, err := d.q.ListenActivityForTrack(ctx, repository.ListenActivityForTrackParams{
Column1: t1,
Column2: t2,
Column3: stepToInterval(opts.Step),
ID: opts.TrackID,
})
if err != nil {
return nil, err
}
listenActivity = make([]db.ListenActivityItem, len(rows))
for i, row := range rows {
t := db.ListenActivityItem{
Start: row.BucketStart,
Listens: row.ListenCount,
}
listenActivity[i] = t
}
l.Debug().Msgf("Database responded with %d steps", len(rows))
} else {
l.Debug().Msgf("Fetching listen activity for %d %s(s) from %v to %v",
opts.Range, opts.Step, t1.Format("Jan 02, 2006 15:04:05"), t2.Format("Jan 02, 2006 15:04:05"))
rows, err := d.q.ListenActivity(ctx, repository.ListenActivityParams{
Column1: t1,
Column2: t2,
Column3: stepToInterval(opts.Step),
})
if err != nil {
return nil, err
}
listenActivity = make([]db.ListenActivityItem, len(rows))
for i, row := range rows {
t := db.ListenActivityItem{
Start: row.BucketStart,
Listens: row.ListenCount,
}
listenActivity[i] = t
}
l.Debug().Msgf("Database responded with %d steps", len(rows))
}
return listenActivity, nil
}

View file

@ -0,0 +1,211 @@
package psql_test
import (
"context"
"testing"
"github.com/gabehf/koito/internal/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func flattenListenCounts(items []db.ListenActivityItem) []int64 {
ret := make([]int64, len(items))
for i, v := range items {
ret[i] = v.Listens
}
return ret
}
func TestListenActivity(t *testing.T) {
truncateTestData(t)
err := store.Exec(context.Background(),
`INSERT INTO artists (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000001'),
('00000000-0000-0000-0000-000000000002')`)
require.NoError(t, err)
// Move artist names into artist_aliases
err = store.Exec(context.Background(),
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
VALUES (1, 'Artist One', 'Testing', true),
(2, 'Artist Two', 'Testing', true)`)
require.NoError(t, err)
// Insert release groups
err = store.Exec(context.Background(),
`INSERT INTO releases (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000011'),
('00000000-0000-0000-0000-000000000022')`)
require.NoError(t, err)
// Move release titles into release_aliases
err = store.Exec(context.Background(),
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
VALUES (1, 'Release One', 'Testing', true),
(2, 'Release Two', 'Testing', true)`)
require.NoError(t, err)
// Insert tracks
err = store.Exec(context.Background(),
`INSERT INTO tracks (musicbrainz_id, release_id)
VALUES ('11111111-1111-1111-1111-111111111111', 1),
('22222222-2222-2222-2222-222222222222', 2)`)
require.NoError(t, err)
// Move track titles into track_aliases
err = store.Exec(context.Background(),
`INSERT INTO track_aliases (track_id, alias, source, is_primary)
VALUES (1, 'Track One', 'Testing', true),
(2, 'Track Two', 'Testing', true)`)
require.NoError(t, err)
// Associate tracks with artists
err = store.Exec(context.Background(),
`INSERT INTO artist_tracks (artist_id, track_id)
VALUES (1, 1), (2, 2)`)
require.NoError(t, err)
// Insert listens
err = store.Exec(context.Background(),
`INSERT INTO listens (user_id, track_id, listened_at)
VALUES (1, 1, NOW() - INTERVAL '1 day'),
(1, 1, NOW() - INTERVAL '2 days'),
(1, 1, NOW() - INTERVAL '1 week 1 day'),
(1, 1, NOW() - INTERVAL '1 month 1 day'),
(1, 1, NOW() - INTERVAL '1 year 1 day'),
(1, 2, NOW() - INTERVAL '1 day'),
(1, 2, NOW() - INTERVAL '2 days'),
(1, 2, NOW() - INTERVAL '1 week 1 day'),
(1, 2, NOW() - INTERVAL '1 month 1 day'),
(1, 2, NOW() - INTERVAL '1 year 1 day')`)
require.NoError(t, err)
ctx := context.Background()
// Test for opts.Step = db.StepDay
activity, err := store.GetListenActivity(ctx, db.ListenActivityOpts{Step: db.StepDay})
require.NoError(t, err)
require.Len(t, activity, db.DefaultRange)
assert.Equal(t, []int64{0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 2, 0}, flattenListenCounts(activity))
// Truncate listens table and insert specific dates for testing opts.Step = db.StepMonth
err = store.Exec(context.Background(), `TRUNCATE TABLE listens`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO listens (user_id, track_id, listened_at)
VALUES (1, 1, NOW() - INTERVAL '1 month'),
(1, 1, NOW() - INTERVAL '2 months'),
(1, 1, NOW() - INTERVAL '3 months'),
(1, 2, NOW() - INTERVAL '1 month'),
(1, 2, NOW() - INTERVAL '2 months')`)
require.NoError(t, err)
activity, err = store.GetListenActivity(ctx, db.ListenActivityOpts{Step: db.StepMonth, Range: 8})
require.NoError(t, err)
require.Len(t, activity, 8)
assert.Equal(t, []int64{0, 0, 0, 0, 1, 2, 2, 0}, flattenListenCounts(activity))
// Truncate listens table and insert specific dates for testing opts.Step = db.StepYear
err = store.Exec(context.Background(), `TRUNCATE TABLE listens RESTART IDENTITY`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO listens (user_id, track_id, listened_at)
VALUES (1, 1, NOW() - INTERVAL '1 year'),
(1, 1, NOW() - INTERVAL '2 years'),
(1, 2, NOW() - INTERVAL '1 year'),
(1, 2, NOW() - INTERVAL '3 years')`)
require.NoError(t, err)
activity, err = store.GetListenActivity(ctx, db.ListenActivityOpts{Step: db.StepYear})
require.NoError(t, err)
require.Len(t, activity, db.DefaultRange)
assert.Equal(t, []int64{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 0}, flattenListenCounts(activity))
// Truncate and insert data for a specific month/year
err = store.Exec(context.Background(), `TRUNCATE TABLE listens RESTART IDENTITY`)
require.NoError(t, err)
err = store.Exec(context.Background(), `
INSERT INTO listens (user_id, track_id, listened_at)
VALUES (1, 1, DATE '2024-03-10'),
(1, 2, DATE '2024-03-20')`)
require.NoError(t, err)
activity, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
Step: db.StepDay,
Month: 3,
Year: 2024,
})
require.NoError(t, err)
require.Len(t, activity, 31) // number of days in march
assert.EqualValues(t, 1, activity[8].Listens)
assert.EqualValues(t, 1, activity[18].Listens)
// Truncate and insert listens associated with two different albums
err = store.Exec(context.Background(), `TRUNCATE TABLE listens RESTART IDENTITY`)
require.NoError(t, err)
err = store.Exec(context.Background(), `
INSERT INTO listens (user_id, track_id, listened_at)
VALUES (1, 1, NOW() - INTERVAL '1 day'), (1, 1, NOW() - INTERVAL '2 days'),
(1, 2, NOW() - INTERVAL '1 day')`)
require.NoError(t, err)
activity, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
Step: db.StepDay,
AlbumID: 1, // Track 1 only
})
require.NoError(t, err)
require.Len(t, activity, db.DefaultRange)
assert.Equal(t, []int64{0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0}, flattenListenCounts(activity))
activity, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
Step: db.StepDay,
TrackID: 1, // Track 1 only
})
require.NoError(t, err)
require.Len(t, activity, db.DefaultRange)
assert.Equal(t, []int64{0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0}, flattenListenCounts(activity))
activity, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
Step: db.StepDay,
ArtistID: 2, // Should only include listens to Track 2
})
require.NoError(t, err)
require.Len(t, activity, db.DefaultRange)
assert.Equal(t, []int64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}, flattenListenCounts(activity))
// month without year is disallowed
_, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
Step: db.StepDay,
Month: 5,
})
require.Error(t, err)
// invalid options
_, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
Year: -10,
})
require.Error(t, err)
_, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
Year: 2025,
Month: -10,
})
require.Error(t, err)
_, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
Range: -1,
})
require.Error(t, err)
_, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
AlbumID: -1,
})
require.Error(t, err)
_, err = store.GetListenActivity(ctx, db.ListenActivityOpts{
ArtistID: -1,
})
require.Error(t, err)
}

View file

@ -0,0 +1,219 @@
package psql_test
import (
"context"
"testing"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func testDataForListens(t *testing.T) {
truncateTestData(t)
// Insert artists
err := store.Exec(context.Background(),
`INSERT INTO artists (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000001'),
('00000000-0000-0000-0000-000000000002')`)
require.NoError(t, err)
// Insert artist aliases
err = store.Exec(context.Background(),
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
VALUES (1, 'Artist One', 'Testing', true),
(2, 'Artist Two', 'Testing', true)`)
require.NoError(t, err)
// Insert release groups
err = store.Exec(context.Background(),
`INSERT INTO releases (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000011'),
('00000000-0000-0000-0000-000000000022')`)
require.NoError(t, err)
// Insert release aliases
err = store.Exec(context.Background(),
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
VALUES (1, 'Release One', 'Testing', true),
(2, 'Release Two', 'Testing', true)`)
require.NoError(t, err)
// Insert tracks
err = store.Exec(context.Background(),
`INSERT INTO tracks (musicbrainz_id, release_id)
VALUES ('11111111-1111-1111-1111-111111111111', 1),
('22222222-2222-2222-2222-222222222222', 2)`)
require.NoError(t, err)
// Insert track aliases
err = store.Exec(context.Background(),
`INSERT INTO track_aliases (track_id, alias, source, is_primary)
VALUES (1, 'Track One', 'Testing', true),
(2, 'Track Two', 'Testing', true)`)
require.NoError(t, err)
// Insert artist track associations
err = store.Exec(context.Background(),
`INSERT INTO artist_tracks (track_id, artist_id)
VALUES (1, 1),
(2, 2)`)
require.NoError(t, err)
}
func TestGetListens(t *testing.T) {
testDataForTopItems(t)
ctx := context.Background()
// Test valid
resp, err := store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime})
require.NoError(t, err)
require.Len(t, resp.Items, 10)
assert.Equal(t, int64(10), resp.TotalCount)
require.Len(t, resp.Items[0].Track.Artists, 1)
require.Len(t, resp.Items[1].Track.Artists, 1)
// ensure tracks are in the right order (time, desc)
assert.Equal(t, "Artist Four", resp.Items[0].Track.Artists[0].Name)
assert.Equal(t, "Artist Three", resp.Items[1].Track.Artists[0].Name)
// Test pagination
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 2, Period: db.PeriodAllTime})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
require.Len(t, resp.Items[0].Track.Artists, 1)
assert.Equal(t, true, resp.HasNextPage)
assert.EqualValues(t, 2, resp.CurrentPage)
assert.EqualValues(t, 1, resp.ItemsPerPage)
assert.EqualValues(t, 10, resp.TotalCount)
assert.Equal(t, "Artist Three", resp.Items[0].Track.Artists[0].Name)
// Test page out of range
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Limit: 10, Page: 10, Period: db.PeriodAllTime})
require.NoError(t, err)
assert.Empty(t, resp.Items)
assert.False(t, resp.HasNextPage)
// Test invalid inputs
_, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Limit: -1, Page: 0})
assert.Error(t, err)
_, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: -1})
assert.Error(t, err)
// Test specify period
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodDay})
require.NoError(t, err)
require.Len(t, resp.Items, 0) // empty
assert.Equal(t, int64(0), resp.TotalCount)
// should default to PeriodDay
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{})
require.NoError(t, err)
require.Len(t, resp.Items, 0) // empty
assert.Equal(t, int64(0), resp.TotalCount)
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodWeek})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodMonth})
require.NoError(t, err)
require.Len(t, resp.Items, 3)
assert.Equal(t, int64(3), resp.TotalCount)
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodYear})
require.NoError(t, err)
require.Len(t, resp.Items, 6)
assert.Equal(t, int64(6), resp.TotalCount)
// Test filter by artists, releases, and tracks
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, ArtistID: 1})
require.NoError(t, err)
require.Len(t, resp.Items, 4)
assert.Equal(t, int64(4), resp.TotalCount)
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, AlbumID: 2})
require.NoError(t, err)
require.Len(t, resp.Items, 3)
assert.Equal(t, int64(3), resp.TotalCount)
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, TrackID: 3})
require.NoError(t, err)
require.Len(t, resp.Items, 2)
assert.Equal(t, int64(2), resp.TotalCount)
// when both artistID and albumID are specified, artist id is ignored
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, AlbumID: 2, ArtistID: 1})
require.NoError(t, err)
require.Len(t, resp.Items, 3)
assert.Equal(t, int64(3), resp.TotalCount)
// Test specify dates
testDataAbsoluteListenTimes(t)
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Year: 2023})
require.NoError(t, err)
require.Len(t, resp.Items, 4)
assert.Equal(t, int64(4), resp.TotalCount)
resp, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Month: 6, Year: 2024})
require.NoError(t, err)
require.Len(t, resp.Items, 3)
assert.Equal(t, int64(3), resp.TotalCount)
// invalid, year required with month
_, err = store.GetListensPaginated(ctx, db.GetItemsOpts{Month: 10})
require.Error(t, err)
}
func TestSaveListen(t *testing.T) {
testDataForListens(t)
ctx := context.Background()
// Test SaveListen with valid inputs
err := store.SaveListen(ctx, db.SaveListenOpts{
TrackID: 1,
Time: time.Now(),
UserID: 1,
})
require.NoError(t, err)
// Verify the listen was saved
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM listens
WHERE track_id = $1
)`, 1)
require.NoError(t, err)
assert.True(t, exists, "expected listen to exist")
// Test SaveListen with missing TrackID
err = store.SaveListen(ctx, db.SaveListenOpts{
TrackID: 0,
Time: time.Now(),
})
assert.Error(t, err)
}
func TestDeleteListen(t *testing.T) {
testDataForListens(t)
ctx := context.Background()
err := store.Exec(ctx, `
INSERT INTO listens (user_id, track_id, listened_at)
VALUES (1, 1, to_timestamp(1749464138.0))`)
require.NoError(t, err)
err = store.DeleteListen(ctx, 1, time.Unix(1749464138, 0))
require.NoError(t, err)
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM listens
WHERE track_id = $1
)`, 1)
require.NoError(t, err)
assert.False(t, exists, "expected listen to be deleted")
}

109
internal/db/psql/merge.go Normal file
View file

@ -0,0 +1,109 @@
package psql
import (
"context"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/repository"
"github.com/jackc/pgx/v5"
)
func (d *Psql) MergeTracks(ctx context.Context, fromId, toId int32) error {
l := logger.FromContext(ctx)
l.Info().Msgf("Merging track %d into track %d", fromId, toId)
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
err = qtx.UpdateTrackIdForListens(ctx, repository.UpdateTrackIdForListensParams{
TrackID: fromId,
TrackID_2: toId,
})
if err != nil {
return err
}
err = qtx.CleanOrphanedEntries(ctx)
if err != nil {
l.Err(err).Msg("Failed to clean orphaned entries")
return err
}
return tx.Commit(ctx)
}
func (d *Psql) MergeAlbums(ctx context.Context, fromId, toId int32) error {
l := logger.FromContext(ctx)
l.Info().Msgf("Merging album %d into album %d", fromId, toId)
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
err = qtx.UpdateReleaseForAll(ctx, repository.UpdateReleaseForAllParams{
ReleaseID: fromId,
ReleaseID_2: toId,
})
if err != nil {
return err
}
err = qtx.CleanOrphanedEntries(ctx)
if err != nil {
l.Err(err).Msg("Failed to clean orphaned entries")
return err
}
return tx.Commit(ctx)
}
func (d *Psql) MergeArtists(ctx context.Context, fromId, toId int32) error {
l := logger.FromContext(ctx)
l.Info().Msgf("Merging artist %d into artist %d", fromId, toId)
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
err = qtx.DeleteConflictingArtistTracks(ctx, repository.DeleteConflictingArtistTracksParams{
ArtistID: fromId,
ArtistID_2: toId,
})
if err != nil {
l.Err(err).Msg("Failed to delete conflicting artist tracks")
return err
}
err = qtx.DeleteConflictingArtistReleases(ctx, repository.DeleteConflictingArtistReleasesParams{
ArtistID: fromId,
ArtistID_2: toId,
})
if err != nil {
l.Err(err).Msg("Failed to delete conflicting artist releases")
return err
}
err = qtx.UpdateArtistTracks(ctx, repository.UpdateArtistTracksParams{
ArtistID: fromId,
ArtistID_2: toId,
})
if err != nil {
l.Err(err).Msg("Failed to update artist tracks")
return err
}
err = qtx.UpdateArtistReleases(ctx, repository.UpdateArtistReleasesParams{
ArtistID: fromId,
ArtistID_2: toId,
})
if err != nil {
l.Err(err).Msg("Failed to update artist releases")
return err
}
err = qtx.CleanOrphanedEntries(ctx)
if err != nil {
l.Err(err).Msg("Failed to clean orphaned entries")
return err
}
return tx.Commit(ctx)
}

View file

@ -0,0 +1,124 @@
package psql_test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setupTestDataForMerge(t *testing.T) {
truncateTestData(t)
// Insert artists
err := store.Exec(context.Background(),
`INSERT INTO artists (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000001'),
('00000000-0000-0000-0000-000000000002')`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
VALUES (1, 'Artist One', 'Testing', true),
(2, 'Artist Two', 'Testing', true)`)
require.NoError(t, err)
// Insert albums
err = store.Exec(context.Background(),
`INSERT INTO releases (musicbrainz_id)
VALUES ('11111111-1111-1111-1111-111111111111'),
('22222222-2222-2222-2222-222222222222')`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
VALUES (1, 'Album One', 'Testing', true),
(2, 'Album Two', 'Testing', true)`)
require.NoError(t, err)
// Insert tracks
err = store.Exec(context.Background(),
`INSERT INTO tracks (musicbrainz_id, release_id)
VALUES ('33333333-3333-3333-3333-333333333333', 1),
('44444444-4444-4444-4444-444444444444', 2)`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO track_aliases (track_id, alias, source, is_primary)
VALUES (1, 'Track One', 'Testing', true),
(2, 'Track Two', 'Testing', true)`)
require.NoError(t, err)
// Associate artists with albums and tracks
err = store.Exec(context.Background(),
`INSERT INTO artist_releases (artist_id, release_id)
VALUES (1, 1), (2, 2)`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO artist_tracks (artist_id, track_id)
VALUES (1, 1), (2, 2)`)
require.NoError(t, err)
// Insert listens
err = store.Exec(context.Background(),
`INSERT INTO listens (user_id, track_id, listened_at)
VALUES (1, 1, NOW() - INTERVAL '1 day'),
(1, 2, NOW() - INTERVAL '2 days')`)
require.NoError(t, err)
}
func TestMergeTracks(t *testing.T) {
ctx := context.Background()
setupTestDataForMerge(t)
// Merge Track 1 into Track 2
err := store.MergeTracks(ctx, 1, 2)
require.NoError(t, err)
// Verify listens are updated
var count int
count, err = store.Count(ctx, `SELECT COUNT(*) FROM listens WHERE track_id = 2`)
require.NoError(t, err)
assert.Equal(t, 2, count, "expected all listens to be merged into Track 2")
truncateTestData(t)
}
func TestMergeAlbums(t *testing.T) {
ctx := context.Background()
setupTestDataForMerge(t)
// Merge Album 1 into Album 2
err := store.MergeAlbums(ctx, 1, 2)
require.NoError(t, err)
// Verify tracks are updated
var count int
count, err = store.Count(ctx, `SELECT COUNT(*) FROM tracks WHERE release_id = 2`)
require.NoError(t, err)
assert.Equal(t, 2, count, "expected all tracks to be merged into Album 2")
truncateTestData(t)
}
func TestMergeArtists(t *testing.T) {
ctx := context.Background()
setupTestDataForMerge(t)
// Merge Artist 1 into Artist 2
err := store.MergeArtists(ctx, 1, 2)
require.NoError(t, err)
// Verify artist associations are updated
var count int
count, err = store.Count(ctx, `SELECT COUNT(*) FROM artist_tracks WHERE artist_id = 2`)
require.NoError(t, err)
assert.Equal(t, 2, count, "expected all tracks to be associated with Artist 2")
count, err = store.Count(ctx, `SELECT COUNT(*) FROM artist_releases WHERE artist_id = 2`)
require.NoError(t, err)
assert.Equal(t, 2, count, "expected all releases to be associated with Artist 2")
truncateTestData(t)
}

119
internal/db/psql/psql.go Normal file
View file

@ -0,0 +1,119 @@
// package psql implements the db.DB interface using psx and a sql generated repository
package psql
import (
"context"
"database/sql"
"fmt"
"path/filepath"
"runtime"
"time"
"github.com/gabehf/koito/internal/cfg"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/repository"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/pressly/goose/v3"
)
const (
DefaultItemsPerPage = 20
)
type Psql struct {
q *repository.Queries
conn *pgxpool.Pool
}
func New() (*Psql, error) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
config, err := pgxpool.ParseConfig(cfg.DatabaseUrl())
if err != nil {
return nil, fmt.Errorf("failed to parse pgx config: %w", err)
}
config.ConnConfig.ConnectTimeout = 15 * time.Second
pool, err := pgxpool.NewWithConfig(ctx, config)
if err != nil {
return nil, fmt.Errorf("failed to create pgx pool: %w", err)
}
if err := pool.Ping(ctx); err != nil {
pool.Close()
return nil, fmt.Errorf("database not reachable: %w", err)
}
sqlDB, err := sql.Open("pgx", cfg.DatabaseUrl())
if err != nil {
return nil, fmt.Errorf("failed to open db for migrations: %w", err)
}
_, filename, _, ok := runtime.Caller(0)
if !ok {
return nil, fmt.Errorf("unable to get caller info")
}
migrationsPath := filepath.Join(filepath.Dir(filename), "..", "..", "..", "db", "migrations")
if err := goose.Up(sqlDB, migrationsPath); err != nil {
return nil, fmt.Errorf("goose failed: %w", err)
}
_ = sqlDB.Close()
return &Psql{
q: repository.New(pool),
conn: pool,
}, nil
}
// Not part of the DB interface this package implements. Only used for testing.
func (d *Psql) Exec(ctx context.Context, query string, args ...any) error {
_, err := d.conn.Exec(ctx, query, args...)
return err
}
// Not part of the DB interface this package implements. Only used for testing.
func (d *Psql) RowExists(ctx context.Context, query string, args ...any) (bool, error) {
var exists bool
err := d.conn.QueryRow(ctx, query, args...).Scan(&exists)
return exists, err
}
func (p *Psql) Count(ctx context.Context, query string, args ...any) (count int, err error) {
err = p.conn.QueryRow(ctx, query, args...).Scan(&count)
return
}
// Exposes p.conn.QueryRow. Only used for testing. Not part of the DB interface this package implements.
func (p *Psql) QueryRow(ctx context.Context, query string, args ...any) pgx.Row {
return p.conn.QueryRow(ctx, query, args...)
}
func (d *Psql) Close(ctx context.Context) {
d.conn.Close()
}
func (d *Psql) Ping(ctx context.Context) error {
return d.conn.Ping(ctx)
}
func stepToInterval(p db.StepInterval) pgtype.Interval {
var interval pgtype.Interval
switch p {
case db.StepDay:
interval.Days = 1
case db.StepWeek:
interval.Days = 7
case db.StepMonth:
interval.Months = 1
case db.StepYear:
interval.Months = 12
}
interval.Valid = true
return interval
}

View file

@ -0,0 +1,186 @@
package psql_test
import (
"context"
"fmt"
"log"
"testing"
"github.com/gabehf/koito/internal/cfg"
"github.com/gabehf/koito/internal/db/psql"
_ "github.com/gabehf/koito/testing_init"
"github.com/ory/dockertest/v3"
"github.com/stretchr/testify/require"
)
var store *psql.Psql
func getTestGetenv(resource *dockertest.Resource) func(string) string {
return func(env string) string {
switch env {
case cfg.DATABASE_URL_ENV:
return fmt.Sprintf("postgres://postgres:secret@localhost:%s", resource.GetPort("5432/tcp"))
default:
return ""
}
}
}
func TestMain(m *testing.M) {
// uses a sensible default on windows (tcp/http) and linux/osx (socket)
pool, err := dockertest.NewPool("")
if err != nil {
log.Fatalf("Could not construct pool: %s", err)
}
// uses pool to try to connect to Docker
err = pool.Client.Ping()
if err != nil {
log.Fatalf("Could not connect to Docker: %s", err)
}
// pulls an image, creates a container based on it and runs it
resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret"})
if err != nil {
log.Fatalf("Could not start resource: %s", err)
}
err = cfg.Load(getTestGetenv(resource))
if err != nil {
log.Fatalf("Could not load cfg: %s", err)
}
// exponential backoff-retry, because the application in the container might not be ready to accept connections yet
if err := pool.Retry(func() error {
var err error
store, err = psql.New()
if err != nil {
log.Println("Failed to connect to test database, retrying...")
return err
}
return store.Ping(context.Background())
}); err != nil {
log.Fatalf("Could not connect to database: %s", err)
}
// as of go1.15 testing.M returns the exit code of m.Run(), so it is safe to use defer here
defer func() {
if err := pool.Purge(resource); err != nil {
log.Fatalf("Could not purge resource: %s", err)
}
}()
// insert a user into the db with id 1 to use for tests
err = store.Exec(context.Background(), `INSERT INTO users (username, password) VALUES ('test', DECODE('abc123', 'hex'))`)
if err != nil {
log.Fatalf("Failed to insert test user: %v", err)
}
m.Run()
}
func testDataForTopItems(t *testing.T) {
truncateTestData(t)
// artist 1 has most listens older than 1 year
// artist 2 has most listens older than 1 month
// artist 3 has most listens older than 1 week
// artist 4 has least listens
err := store.Exec(context.Background(),
`INSERT INTO artists (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000001'),
('00000000-0000-0000-0000-000000000002'),
('00000000-0000-0000-0000-000000000003'),
('00000000-0000-0000-0000-000000000004')`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
VALUES (1, 'Artist One', 'Testing', true),
(2, 'Artist Two', 'Testing', true),
(3, 'Artist Three', 'Testing', true),
(4, 'Artist Four', 'Testing', true)`)
require.NoError(t, err)
// Insert release groups
err = store.Exec(context.Background(),
`INSERT INTO releases (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000011'),
('00000000-0000-0000-0000-000000000022'),
('00000000-0000-0000-0000-000000000033'),
('00000000-0000-0000-0000-000000000044')`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
VALUES (1, 'Release One', 'Testing', true),
(2, 'Release Two', 'Testing', true),
(3, 'Release Three', 'Testing', true),
(4, 'Release Four', 'Testing', true)`)
require.NoError(t, err)
// Insert release groups
err = store.Exec(context.Background(),
`INSERT INTO artist_releases (release_id, artist_id)
VALUES (1, 1), (2, 2), (3, 3), (4, 4)`)
require.NoError(t, err)
// Insert tracks
err = store.Exec(context.Background(),
`INSERT INTO tracks (musicbrainz_id, release_id, duration)
VALUES ('11111111-1111-1111-1111-111111111111', 1, 100),
('22222222-2222-2222-2222-222222222222', 2, 100),
('33333333-3333-3333-3333-333333333333', 3, 100),
('44444444-4444-4444-4444-444444444444', 4, 100)`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO track_aliases (track_id, alias, source, is_primary)
VALUES (1, 'Track One', 'Testing', true),
(2, 'Track Two', 'Testing', true),
(3, 'Track Three', 'Testing', true),
(4, 'Track Four', 'Testing', true)`)
require.NoError(t, err)
// Associate tracks with artists
err = store.Exec(context.Background(),
`INSERT INTO artist_tracks (artist_id, track_id)
VALUES (1, 1), (2, 2), (3, 3), (4, 4)`)
require.NoError(t, err)
// Insert listens
err = store.Exec(context.Background(),
`INSERT INTO listens (user_id, track_id, listened_at)
VALUES (1, 1, NOW() - INTERVAL '2 years 1 day'),
(1, 1, NOW() - INTERVAL '2 years 2 days'),
(1, 1, NOW() - INTERVAL '2 years 3 days'),
(1, 1, NOW() - INTERVAL '2 years 4 days'),
(1, 2, NOW() - INTERVAL '2 months 1 day'),
(1, 2, NOW() - INTERVAL '2 months 2 days'),
(1, 2, NOW() - INTERVAL '2 months 3 days'),
(1, 3, NOW() - INTERVAL '2 weeks'),
(1, 3, NOW() - INTERVAL '2 weeks 1 day'),
(1, 4, NOW() - INTERVAL '2 days')`)
require.NoError(t, err)
}
func testDataAbsoluteListenTimes(t *testing.T) {
err := store.Exec(context.Background(),
`TRUNCATE listens`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO listens (user_id, track_id, listened_at)
VALUES (1, 1, '2023-06-22 19:11:25-07'),
(1, 1, '2023-06-22 19:12:25-07'),
(1, 1, '2023-06-22 19:13:25-07'),
(1, 1, '2023-06-22 19:14:25-07'),
(1, 2, '2024-06-22 19:15:25-07'),
(1, 2, '2024-06-22 19:16:25-07'),
(1, 2, '2024-06-22 19:17:25-07'),
(1, 3, '2024-10-02 19:18:25-07'),
(1, 3, '2024-10-02 19:19:25-07'),
(1, 4, '2025-05-16 19:20:25-07')`)
require.NoError(t, err)
}

151
internal/db/psql/search.go Normal file
View file

@ -0,0 +1,151 @@
package psql
import (
"context"
"encoding/json"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/jackc/pgx/v5/pgtype"
)
const searchItemLimit = 5
const substringSearchLength = 6
func (d *Psql) SearchArtists(ctx context.Context, q string) ([]*models.Artist, error) {
if len(q) < substringSearchLength {
rows, err := d.q.SearchArtistsBySubstring(ctx, repository.SearchArtistsBySubstringParams{
Column1: pgtype.Text{String: q, Valid: true},
Limit: searchItemLimit,
})
if err != nil {
return nil, err
}
ret := make([]*models.Artist, len(rows))
for i, row := range rows {
ret[i] = &models.Artist{
ID: row.ID,
MbzID: row.MusicBrainzID,
Name: row.Name,
Image: row.Image,
}
}
return ret, nil
} else {
rows, err := d.q.SearchArtists(ctx, repository.SearchArtistsParams{
Similarity: q,
Limit: searchItemLimit,
})
if err != nil {
return nil, err
}
ret := make([]*models.Artist, len(rows))
for i, row := range rows {
ret[i] = &models.Artist{
ID: row.ID,
MbzID: row.MusicBrainzID,
Name: row.Name,
Image: row.Image,
}
}
return ret, nil
}
}
func (d *Psql) SearchAlbums(ctx context.Context, q string) ([]*models.Album, error) {
if len(q) < substringSearchLength {
rows, err := d.q.SearchReleasesBySubstring(ctx, repository.SearchReleasesBySubstringParams{
Column1: pgtype.Text{String: q, Valid: true},
Limit: searchItemLimit,
})
if err != nil {
return nil, err
}
ret := make([]*models.Album, len(rows))
for i, row := range rows {
ret[i] = &models.Album{
ID: row.ID,
MbzID: row.MusicBrainzID,
Title: row.Title,
VariousArtists: row.VariousArtists,
Image: row.Image,
}
err = json.Unmarshal(row.Artists, &ret[i].Artists)
if err != nil {
return nil, err
}
}
return ret, nil
} else {
rows, err := d.q.SearchReleases(ctx, repository.SearchReleasesParams{
Similarity: q,
Limit: searchItemLimit,
})
if err != nil {
return nil, err
}
ret := make([]*models.Album, len(rows))
for i, row := range rows {
ret[i] = &models.Album{
ID: row.ID,
MbzID: row.MusicBrainzID,
Title: row.Title,
VariousArtists: row.VariousArtists,
Image: row.Image,
}
err = json.Unmarshal(row.Artists, &ret[i].Artists)
if err != nil {
return nil, err
}
}
return ret, nil
}
}
func (d *Psql) SearchTracks(ctx context.Context, q string) ([]*models.Track, error) {
if len(q) < substringSearchLength {
rows, err := d.q.SearchTracksBySubstring(ctx, repository.SearchTracksBySubstringParams{
Column1: pgtype.Text{String: q, Valid: true},
Limit: searchItemLimit,
})
if err != nil {
return nil, err
}
ret := make([]*models.Track, len(rows))
for i, row := range rows {
ret[i] = &models.Track{
ID: row.ID,
MbzID: row.MusicBrainzID,
Title: row.Title,
Image: row.Image,
}
err = json.Unmarshal(row.Artists, &ret[i].Artists)
if err != nil {
return nil, err
}
}
return ret, nil
} else {
rows, err := d.q.SearchTracks(ctx, repository.SearchTracksParams{
Similarity: q,
Limit: searchItemLimit,
})
if err != nil {
return nil, err
}
ret := make([]*models.Track, len(rows))
for i, row := range rows {
ret[i] = &models.Track{
ID: row.ID,
MbzID: row.MusicBrainzID,
Title: row.Title,
Image: row.Image,
}
err = json.Unmarshal(row.Artists, &ret[i].Artists)
if err != nil {
return nil, err
}
}
return ret, nil
}
}

View file

@ -0,0 +1,116 @@
package psql_test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setupTestDataForSearch(t *testing.T) {
truncateTestData(t)
// Insert artists
err := store.Exec(context.Background(),
`INSERT INTO artists (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000001'),
('00000000-0000-0000-0000-000000000002')`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
VALUES (1, 'Artist One With A Long Name', 'Testing', true),
(2, 'Artist Two', 'Testing', true)`)
require.NoError(t, err)
// Insert albums
err = store.Exec(context.Background(),
`INSERT INTO releases (musicbrainz_id, various_artists)
VALUES ('11111111-1111-1111-1111-111111111111', false),
('22222222-2222-2222-2222-222222222222', true)`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
VALUES (1, 'Album One With A Long Name', 'Testing', true),
(2, 'Album Two', 'Testing', true)`)
require.NoError(t, err)
// Insert tracks
err = store.Exec(context.Background(),
`INSERT INTO tracks (musicbrainz_id, release_id)
VALUES ('33333333-3333-3333-3333-333333333333', 1),
('44444444-4444-4444-4444-444444444444', 2)`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO track_aliases (track_id, alias, source, is_primary)
VALUES (1, 'Track One With A Long Name', 'Testing', true),
(2, 'Track Two', 'Testing', true)`)
require.NoError(t, err)
// Associate artists with albums and tracks
err = store.Exec(context.Background(),
`INSERT INTO artist_releases (artist_id, release_id)
VALUES (1, 1), (2, 2)`)
require.NoError(t, err)
err = store.Exec(context.Background(),
`INSERT INTO artist_tracks (artist_id, track_id)
VALUES (1, 1), (2, 2)`)
require.NoError(t, err)
}
func TestSearchArtists(t *testing.T) {
ctx := context.Background()
setupTestDataForSearch(t)
// Search for "Artist One With A Long Name"
results, err := store.SearchArtists(ctx, "Artist One With A Long Name")
require.NoError(t, err)
require.Len(t, results, 1)
assert.Equal(t, "Artist One With A Long Name", results[0].Name)
// Search for substring "Artist"
results, err = store.SearchArtists(ctx, "Arti")
require.NoError(t, err)
require.Len(t, results, 2)
truncateTestData(t)
}
func TestSearchAlbums(t *testing.T) {
ctx := context.Background()
setupTestDataForSearch(t)
// Search for "Album One With A Long Name"
results, err := store.SearchAlbums(ctx, "Album One With A Long Name")
require.NoError(t, err)
require.Len(t, results, 1)
assert.Equal(t, "Album One With A Long Name", results[0].Title)
// Search for substring "Album"
results, err = store.SearchAlbums(ctx, "Albu")
require.NoError(t, err)
require.Len(t, results, 2)
assert.NotNil(t, results[0].Artists)
truncateTestData(t)
}
func TestSearchTracks(t *testing.T) {
ctx := context.Background()
setupTestDataForSearch(t)
// Search for "Track One With A Long Name"
results, err := store.SearchTracks(ctx, "Track One With A Long Name")
require.NoError(t, err)
require.Len(t, results, 1)
assert.Equal(t, "Track One With A Long Name", results[0].Title)
// Search for substring "Track"
results, err = store.SearchTracks(ctx, "Trac")
require.NoError(t, err)
require.Len(t, results, 2)
assert.NotNil(t, results[0].Artists)
truncateTestData(t)
}

View file

@ -0,0 +1,59 @@
package psql
import (
"context"
"errors"
"time"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
)
func (d *Psql) SaveSession(ctx context.Context, userID int32, expiresAt time.Time, persistent bool) (*models.Session, error) {
session, err := d.q.InsertSession(ctx, repository.InsertSessionParams{
ID: uuid.New(),
UserID: userID,
ExpiresAt: expiresAt,
Persistent: persistent,
})
if err != nil {
return nil, err
}
return &models.Session{
ID: session.ID,
UserID: session.UserID,
CreatedAt: session.CreatedAt,
ExpiresAt: session.ExpiresAt,
Persistent: session.Persistent,
}, nil
}
func (d *Psql) RefreshSession(ctx context.Context, sessionId uuid.UUID, expiresAt time.Time) error {
return d.q.UpdateSessionExpiry(ctx, repository.UpdateSessionExpiryParams{
ID: sessionId,
ExpiresAt: expiresAt,
})
}
func (d *Psql) DeleteSession(ctx context.Context, sessionId uuid.UUID) error {
return d.q.DeleteSession(ctx, sessionId)
}
// Returns nil, nil when no database entries are found
func (d *Psql) GetUserBySession(ctx context.Context, sessionId uuid.UUID) (*models.User, error) {
row, err := d.q.GetUserBySession(ctx, sessionId)
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
}
return &models.User{
ID: row.ID,
Username: row.Username,
Password: row.Password,
Role: models.UserRole(row.Role),
}, nil
}

View file

@ -0,0 +1,101 @@
package psql_test
import (
"context"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func truncateTestDataForSessions(t *testing.T) {
err := store.Exec(context.Background(),
`TRUNCATE
sessions
RESTART IDENTITY CASCADE`,
)
require.NoError(t, err)
}
func TestSaveSession(t *testing.T) {
ctx := context.Background()
// Save a session for the user
expiresAt := time.Now().Add(24 * time.Hour).UTC()
session, err := store.SaveSession(ctx, 1, expiresAt, true)
require.NoError(t, err)
require.NotNil(t, session)
assert.Equal(t, int32(1), session.UserID)
assert.Equal(t, true, session.Persistent)
assert.WithinDuration(t, expiresAt, session.ExpiresAt, time.Second)
truncateTestDataForSessions(t)
}
func TestRefreshSession(t *testing.T) {
ctx := context.Background()
// Save a session first
expiresAt := time.Now().Add(-1 * time.Minute)
session, err := store.SaveSession(ctx, 1, expiresAt, true)
require.NoError(t, err)
// Refresh the session expiry
newExpiresAt := time.Now().Add(48 * time.Hour)
err = store.RefreshSession(ctx, session.ID, newExpiresAt)
require.NoError(t, err)
// Can only retrieve a session with an expiresAt > time.Now()
_, err = store.GetUserBySession(ctx, session.ID)
require.NoError(t, err)
truncateTestDataForSessions(t)
}
func TestDeleteSession(t *testing.T) {
ctx := context.Background()
// Save a session first
expiresAt := time.Now().Add(24 * time.Hour)
session, err := store.SaveSession(ctx, 1, expiresAt, true)
require.NoError(t, err)
// Delete the session
err = store.DeleteSession(ctx, session.ID)
require.NoError(t, err)
// Verify the session was deleted
var count int
count, err = store.Count(ctx, `SELECT COUNT(*) FROM sessions WHERE id = $1`, session.ID)
require.NoError(t, err)
assert.Equal(t, 0, count)
truncateTestDataForSessions(t)
}
func TestGetUserBySession(t *testing.T) {
ctx := context.Background()
// Save a session first
expiresAt := time.Now().Add(24 * time.Hour)
session, err := store.SaveSession(ctx, 1, expiresAt, true)
require.NoError(t, err)
// Get the user by session
user, err := store.GetUserBySession(ctx, session.ID)
require.NoError(t, err)
require.NotNil(t, user)
assert.Equal(t, int32(1), user.ID)
assert.Equal(t, "test", user.Username)
assert.Equal(t, []uint8([]byte{0xab, 0xc1, 0x23}), user.Password)
assert.Equal(t, "user", string(user.Role))
// Test for a non-existent session
nonExistentSessionID := uuid.New()
user, err = store.GetUserBySession(ctx, nonExistentSessionID)
require.NoError(t, err)
assert.Nil(t, user)
truncateTestDataForSessions(t)
}

View file

@ -0,0 +1,119 @@
package psql
import (
"context"
"encoding/json"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/gabehf/koito/internal/utils"
)
func (d *Psql) GetTopAlbumsPaginated(ctx context.Context, opts db.GetItemsOpts) (*db.PaginatedResponse[*models.Album], error) {
l := logger.FromContext(ctx)
offset := (opts.Page - 1) * opts.Limit
t1, t2, err := utils.DateRange(opts.Week, opts.Month, opts.Year)
if err != nil {
return nil, err
}
if opts.Month == 0 && opts.Year == 0 {
// use period, not date range
t2 = time.Now()
t1 = db.StartTimeFromPeriod(opts.Period)
}
if opts.Limit == 0 {
opts.Limit = DefaultItemsPerPage
}
var rgs []*models.Album
var count int64
if opts.ArtistID != 0 {
l.Debug().Msgf("Fetching top %d albums from artist id %d with period %s on page %d from range %v to %v",
opts.Limit, opts.ArtistID, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetTopReleasesFromArtist(ctx, repository.GetTopReleasesFromArtistParams{
ArtistID: int32(opts.ArtistID),
Limit: int32(opts.Limit),
Offset: int32(offset),
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return nil, err
}
rgs = make([]*models.Album, len(rows))
l.Debug().Msgf("Database responded with %d items", len(rows))
for i, v := range rows {
artists := make([]models.SimpleArtist, 0)
err = json.Unmarshal(v.Artists, &artists)
if err != nil {
l.Err(err).Msgf("Error unmarshalling artists for release group with id %d", v.ID)
artists = nil
}
rgs[i] = &models.Album{
ID: v.ID,
MbzID: v.MusicBrainzID,
Title: v.Title,
Image: v.Image,
Artists: artists,
VariousArtists: v.VariousArtists,
ListenCount: v.ListenCount,
}
}
count, err = d.q.CountReleasesFromArtist(ctx, int32(opts.ArtistID))
if err != nil {
return nil, err
}
} else {
l.Debug().Msgf("Fetching top %d albums with period %s on page %d from range %v to %v",
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetTopReleasesPaginated(ctx, repository.GetTopReleasesPaginatedParams{
ListenedAt: t1,
ListenedAt_2: t2,
Limit: int32(opts.Limit),
Offset: int32(offset),
})
if err != nil {
return nil, err
}
rgs = make([]*models.Album, len(rows))
l.Debug().Msgf("Database responded with %d items", len(rows))
for i, row := range rows {
artists := make([]models.SimpleArtist, 0)
err = json.Unmarshal(row.Artists, &artists)
if err != nil {
l.Err(err).Msgf("Error unmarshalling artists for release group with id %d", row.ID)
artists = nil
}
t := &models.Album{
Title: row.Title,
MbzID: row.MusicBrainzID,
ID: row.ID,
Image: row.Image,
Artists: artists,
VariousArtists: row.VariousArtists,
ListenCount: row.ListenCount,
}
rgs[i] = t
}
count, err = d.q.CountTopReleases(ctx, repository.CountTopReleasesParams{
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return nil, err
}
l.Debug().Msgf("Database responded with %d albums out of a total %d", len(rows), count)
}
return &db.PaginatedResponse[*models.Album]{
Items: rgs,
TotalCount: count,
ItemsPerPage: int32(opts.Limit),
HasNextPage: int64(offset+len(rgs)) < count,
CurrentPage: int32(opts.Page),
}, nil
}

View file

@ -0,0 +1,103 @@
package psql_test
import (
"context"
"testing"
"github.com/gabehf/koito/internal/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetTopAlbumsPaginated(t *testing.T) {
testDataForTopItems(t)
ctx := context.Background()
// Test valid
resp, err := store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime})
require.NoError(t, err)
require.Len(t, resp.Items, 4)
assert.Equal(t, int64(4), resp.TotalCount)
assert.Equal(t, "Release One", resp.Items[0].Title)
assert.Equal(t, "Release Two", resp.Items[1].Title)
assert.Equal(t, "Release Three", resp.Items[2].Title)
assert.Equal(t, "Release Four", resp.Items[3].Title)
// Test pagination
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 2, Period: db.PeriodAllTime})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, "Release Two", resp.Items[0].Title)
// Test page out of range
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 10, Period: db.PeriodAllTime})
require.NoError(t, err)
require.Empty(t, resp.Items)
assert.False(t, resp.HasNextPage)
// Test invalid inputs
_, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Limit: -1, Page: 0})
assert.Error(t, err)
_, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: -1})
assert.Error(t, err)
// Test specify period
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodDay})
require.NoError(t, err)
require.Len(t, resp.Items, 0) // empty
assert.Equal(t, int64(0), resp.TotalCount)
// should default to PeriodDay
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{})
require.NoError(t, err)
require.Len(t, resp.Items, 0) // empty
assert.Equal(t, int64(0), resp.TotalCount)
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodWeek})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Release Four", resp.Items[0].Title)
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodMonth})
require.NoError(t, err)
require.Len(t, resp.Items, 2)
assert.Equal(t, int64(2), resp.TotalCount)
assert.Equal(t, "Release Three", resp.Items[0].Title)
assert.Equal(t, "Release Four", resp.Items[1].Title)
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodYear})
require.NoError(t, err)
require.Len(t, resp.Items, 3)
assert.Equal(t, int64(3), resp.TotalCount)
assert.Equal(t, "Release Two", resp.Items[0].Title)
assert.Equal(t, "Release Three", resp.Items[1].Title)
assert.Equal(t, "Release Four", resp.Items[2].Title)
// test specific artist
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodYear, ArtistID: 2})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Release Two", resp.Items[0].Title)
// Test specify dates
testDataAbsoluteListenTimes(t)
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Year: 2023})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Release One", resp.Items[0].Title)
resp, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Month: 6, Year: 2024})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Release Two", resp.Items[0].Title)
// invalid, year required with month
_, err = store.GetTopAlbumsPaginated(ctx, db.GetItemsOpts{Month: 10})
require.Error(t, err)
}

View file

@ -0,0 +1,67 @@
package psql
import (
"context"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/gabehf/koito/internal/utils"
)
func (d *Psql) GetTopArtistsPaginated(ctx context.Context, opts db.GetItemsOpts) (*db.PaginatedResponse[*models.Artist], error) {
l := logger.FromContext(ctx)
offset := (opts.Page - 1) * opts.Limit
t1, t2, err := utils.DateRange(opts.Week, opts.Month, opts.Year)
if err != nil {
return nil, err
}
if opts.Month == 0 && opts.Year == 0 {
// use period, not date range
t2 = time.Now()
t1 = db.StartTimeFromPeriod(opts.Period)
}
if opts.Limit == 0 {
opts.Limit = DefaultItemsPerPage
}
l.Debug().Msgf("Fetching top %d artists with period %s on page %d from range %v to %v",
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetTopArtistsPaginated(ctx, repository.GetTopArtistsPaginatedParams{
ListenedAt: t1,
ListenedAt_2: t2,
Limit: int32(opts.Limit),
Offset: int32(offset),
})
if err != nil {
return nil, err
}
rgs := make([]*models.Artist, len(rows))
for i, row := range rows {
t := &models.Artist{
Name: row.Name,
MbzID: row.MusicBrainzID,
ID: row.ID,
Image: row.Image,
ListenCount: row.ListenCount,
}
rgs[i] = t
}
count, err := d.q.CountTopArtists(ctx, repository.CountTopArtistsParams{
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return nil, err
}
l.Debug().Msgf("Database responded with %d artists out of a total %d", len(rows), count)
return &db.PaginatedResponse[*models.Artist]{
Items: rgs,
TotalCount: count,
ItemsPerPage: int32(opts.Limit),
HasNextPage: int64(offset+len(rgs)) < count,
CurrentPage: int32(opts.Page),
}, nil
}

View file

@ -0,0 +1,96 @@
package psql_test
import (
"context"
"testing"
"github.com/gabehf/koito/internal/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetTopArtistsPaginated(t *testing.T) {
testDataForTopItems(t)
ctx := context.Background()
// Test valid
resp, err := store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime})
require.NoError(t, err)
require.Len(t, resp.Items, 4)
assert.Equal(t, int64(4), resp.TotalCount)
assert.Equal(t, "Artist One", resp.Items[0].Name)
assert.Equal(t, "Artist Two", resp.Items[1].Name)
assert.Equal(t, "Artist Three", resp.Items[2].Name)
assert.Equal(t, "Artist Four", resp.Items[3].Name)
// Test pagination
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 2, Period: db.PeriodAllTime})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, "Artist Two", resp.Items[0].Name)
// Test page out of range
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 10, Period: db.PeriodAllTime})
require.NoError(t, err)
assert.Empty(t, resp.Items)
assert.False(t, resp.HasNextPage)
// Test invalid inputs
_, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Limit: -1, Page: 0})
assert.Error(t, err)
_, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: -1})
assert.Error(t, err)
// Test specify period
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodDay})
require.NoError(t, err)
require.Len(t, resp.Items, 0) // empty
assert.Equal(t, int64(0), resp.TotalCount)
// should default to PeriodDay
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{})
require.NoError(t, err)
require.Len(t, resp.Items, 0) // empty
assert.Equal(t, int64(0), resp.TotalCount)
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodWeek})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Artist Four", resp.Items[0].Name)
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodMonth})
require.NoError(t, err)
require.Len(t, resp.Items, 2)
assert.Equal(t, int64(2), resp.TotalCount)
assert.Equal(t, "Artist Three", resp.Items[0].Name)
assert.Equal(t, "Artist Four", resp.Items[1].Name)
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Period: db.PeriodYear})
require.NoError(t, err)
require.Len(t, resp.Items, 3)
assert.Equal(t, int64(3), resp.TotalCount)
assert.Equal(t, "Artist Two", resp.Items[0].Name)
assert.Equal(t, "Artist Three", resp.Items[1].Name)
assert.Equal(t, "Artist Four", resp.Items[2].Name)
// Test specify dates
testDataAbsoluteListenTimes(t)
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Year: 2023})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Artist One", resp.Items[0].Name)
resp, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Month: 6, Year: 2024})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Artist Two", resp.Items[0].Name)
// invalid, year required with month
_, err = store.GetTopArtistsPaginated(ctx, db.GetItemsOpts{Month: 10})
require.Error(t, err)
}

View file

@ -0,0 +1,160 @@
package psql
import (
"context"
"encoding/json"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/gabehf/koito/internal/utils"
)
func (d *Psql) GetTopTracksPaginated(ctx context.Context, opts db.GetItemsOpts) (*db.PaginatedResponse[*models.Track], error) {
l := logger.FromContext(ctx)
offset := (opts.Page - 1) * opts.Limit
t1, t2, err := utils.DateRange(opts.Week, opts.Month, opts.Year)
if err != nil {
return nil, err
}
if opts.Month == 0 && opts.Year == 0 {
// use period, not date range
t2 = time.Now()
t1 = db.StartTimeFromPeriod(opts.Period)
}
if opts.Limit == 0 {
opts.Limit = DefaultItemsPerPage
}
var tracks []*models.Track
var count int64
if opts.AlbumID > 0 {
l.Debug().Msgf("Fetching top %d tracks with period %s on page %d from range %v to %v",
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetTopTracksInReleasePaginated(ctx, repository.GetTopTracksInReleasePaginatedParams{
ListenedAt: t1,
ListenedAt_2: t2,
Limit: int32(opts.Limit),
Offset: int32(offset),
ReleaseID: int32(opts.AlbumID),
})
if err != nil {
return nil, err
}
tracks = make([]*models.Track, len(rows))
for i, row := range rows {
artists := make([]models.SimpleArtist, 0)
err = json.Unmarshal(row.Artists, &artists)
if err != nil {
l.Err(err).Msgf("Error unmarshalling artists for track with id %d", row.ID)
artists = nil
}
t := &models.Track{
Title: row.Title,
MbzID: row.MusicBrainzID,
ID: row.ID,
ListenCount: row.ListenCount,
Image: row.Image,
AlbumID: row.ReleaseID,
Artists: artists,
}
tracks[i] = t
}
count, err = d.q.CountTopTracksByRelease(ctx, repository.CountTopTracksByReleaseParams{
ListenedAt: t1,
ListenedAt_2: t2,
ReleaseID: int32(opts.AlbumID),
})
if err != nil {
return nil, err
}
} else if opts.ArtistID > 0 {
l.Debug().Msgf("Fetching top %d tracks with period %s on page %d from range %v to %v",
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetTopTracksByArtistPaginated(ctx, repository.GetTopTracksByArtistPaginatedParams{
ListenedAt: t1,
ListenedAt_2: t2,
Limit: int32(opts.Limit),
Offset: int32(offset),
ArtistID: int32(opts.ArtistID),
})
if err != nil {
return nil, err
}
tracks = make([]*models.Track, len(rows))
for i, row := range rows {
artists := make([]models.SimpleArtist, 0)
err = json.Unmarshal(row.Artists, &artists)
if err != nil {
l.Err(err).Msgf("Error unmarshalling artists for track with id %d", row.ID)
artists = nil
}
t := &models.Track{
Title: row.Title,
MbzID: row.MusicBrainzID,
ID: row.ID,
Image: row.Image,
ListenCount: row.ListenCount,
AlbumID: row.ReleaseID,
Artists: artists,
}
tracks[i] = t
}
count, err = d.q.CountTopTracksByArtist(ctx, repository.CountTopTracksByArtistParams{
ListenedAt: t1,
ListenedAt_2: t2,
ArtistID: int32(opts.ArtistID),
})
if err != nil {
return nil, err
}
} else {
l.Debug().Msgf("Fetching top %d tracks with period %s on page %d from range %v to %v",
opts.Limit, opts.Period, opts.Page, t1.Format("Jan 02, 2006"), t2.Format("Jan 02, 2006"))
rows, err := d.q.GetTopTracksPaginated(ctx, repository.GetTopTracksPaginatedParams{
ListenedAt: t1,
ListenedAt_2: t2,
Limit: int32(opts.Limit),
Offset: int32(offset),
})
if err != nil {
return nil, err
}
tracks = make([]*models.Track, len(rows))
for i, row := range rows {
artists := make([]models.SimpleArtist, 0)
err = json.Unmarshal(row.Artists, &artists)
if err != nil {
l.Err(err).Msgf("Error unmarshalling artists for track with id %d", row.ID)
artists = nil
}
t := &models.Track{
Title: row.Title,
MbzID: row.MusicBrainzID,
ID: row.ID,
Image: row.Image,
ListenCount: row.ListenCount,
AlbumID: row.ReleaseID,
Artists: artists,
}
tracks[i] = t
}
count, err = d.q.CountTopTracks(ctx, repository.CountTopTracksParams{
ListenedAt: t1,
ListenedAt_2: t2,
})
if err != nil {
return nil, err
}
l.Debug().Msgf("Database responded with %d tracks out of a total %d", len(rows), count)
}
return &db.PaginatedResponse[*models.Track]{
Items: tracks,
TotalCount: count,
ItemsPerPage: int32(opts.Limit),
HasNextPage: int64(offset+len(tracks)) < count,
CurrentPage: int32(opts.Page),
}, nil
}

View file

@ -0,0 +1,118 @@
package psql_test
import (
"context"
"testing"
"github.com/gabehf/koito/internal/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetTopTracksPaginated(t *testing.T) {
testDataForTopItems(t)
ctx := context.Background()
// Test valid
resp, err := store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime})
require.NoError(t, err)
require.Len(t, resp.Items, 4)
assert.Equal(t, int64(4), resp.TotalCount)
assert.Equal(t, "Track One", resp.Items[0].Title)
assert.Equal(t, "Track Two", resp.Items[1].Title)
assert.Equal(t, "Track Three", resp.Items[2].Title)
assert.Equal(t, "Track Four", resp.Items[3].Title)
// ensure artists are included
require.Len(t, resp.Items[0].Artists, 1)
assert.Equal(t, "Artist One", resp.Items[0].Artists[0].Name)
// Test pagination
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 2, Period: db.PeriodAllTime})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, "Track Two", resp.Items[0].Title)
// Test page out of range
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: 10, Period: db.PeriodAllTime})
require.NoError(t, err)
assert.Empty(t, resp.Items)
assert.False(t, resp.HasNextPage)
// Test invalid inputs
_, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Limit: -1, Page: 0})
assert.Error(t, err)
_, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Limit: 1, Page: -1})
assert.Error(t, err)
// Test specify period
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodDay})
require.NoError(t, err)
require.Len(t, resp.Items, 0) // empty
assert.Equal(t, int64(0), resp.TotalCount)
// should default to PeriodDay
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{})
require.NoError(t, err)
require.Len(t, resp.Items, 0) // empty
assert.Equal(t, int64(0), resp.TotalCount)
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodWeek})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Track Four", resp.Items[0].Title)
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodMonth})
require.NoError(t, err)
require.Len(t, resp.Items, 2)
assert.Equal(t, int64(2), resp.TotalCount)
assert.Equal(t, "Track Three", resp.Items[0].Title)
assert.Equal(t, "Track Four", resp.Items[1].Title)
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodYear})
require.NoError(t, err)
require.Len(t, resp.Items, 3)
assert.Equal(t, int64(3), resp.TotalCount)
assert.Equal(t, "Track Two", resp.Items[0].Title)
assert.Equal(t, "Track Three", resp.Items[1].Title)
assert.Equal(t, "Track Four", resp.Items[2].Title)
// Test filter by artists and releases
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, ArtistID: 1})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Track One", resp.Items[0].Title)
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, AlbumID: 2})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Track Two", resp.Items[0].Title)
// when both artistID and albumID are specified, artist id is ignored
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Period: db.PeriodAllTime, AlbumID: 2, ArtistID: 1})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Track Two", resp.Items[0].Title)
// Test specify dates
testDataAbsoluteListenTimes(t)
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Year: 2023})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Track One", resp.Items[0].Title)
resp, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Month: 6, Year: 2024})
require.NoError(t, err)
require.Len(t, resp.Items, 1)
assert.Equal(t, int64(1), resp.TotalCount)
assert.Equal(t, "Track Two", resp.Items[0].Title)
// invalid, year required with month
_, err = store.GetTopTracksPaginated(ctx, db.GetItemsOpts{Month: 10})
require.Error(t, err)
}

298
internal/db/psql/track.go Normal file
View file

@ -0,0 +1,298 @@
package psql
import (
"context"
"errors"
"strings"
"time"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/gabehf/koito/internal/utils"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
)
func (d *Psql) GetTrack(ctx context.Context, opts db.GetTrackOpts) (*models.Track, error) {
l := logger.FromContext(ctx)
var track models.Track
if opts.ID != 0 {
l.Debug().Msgf("Fetching track from DB with id %d", opts.ID)
t, err := d.q.GetTrack(ctx, opts.ID)
if err != nil {
return nil, err
}
track = models.Track{
ID: t.ID,
MbzID: t.MusicBrainzID,
Title: t.Title,
AlbumID: t.ReleaseID,
Image: t.Image,
Duration: t.Duration,
}
} else if opts.MusicBrainzID != uuid.Nil {
l.Debug().Msgf("Fetching track from DB with MusicBrainz ID %s", opts.MusicBrainzID)
t, err := d.q.GetTrackByMbzID(ctx, &opts.MusicBrainzID)
if err != nil {
return nil, err
}
track = models.Track{
ID: t.ID,
MbzID: t.MusicBrainzID,
Title: t.Title,
AlbumID: t.ReleaseID,
Duration: t.Duration,
}
} else if len(opts.ArtistIDs) > 0 {
l.Debug().Msgf("Fetching track from DB with title '%s' and artist id(s) '%v'", opts.Title, opts.ArtistIDs)
t, err := d.q.GetTrackByTitleAndArtists(ctx, repository.GetTrackByTitleAndArtistsParams{
Title: opts.Title,
Column2: opts.ArtistIDs,
})
if err != nil {
return nil, err
}
track = models.Track{
ID: t.ID,
MbzID: t.MusicBrainzID,
Title: t.Title,
AlbumID: t.ReleaseID,
Duration: t.Duration,
}
} else {
return nil, errors.New("insufficient information to get track")
}
count, err := d.q.CountListensFromTrack(ctx, repository.CountListensFromTrackParams{
ListenedAt: time.Unix(0, 0),
ListenedAt_2: time.Now(),
TrackID: track.ID,
})
if err != nil {
l.Err(err).Msgf("Failed to get listen count for track with id %d", track.ID)
}
track.ListenCount = count
return &track, nil
}
func (d *Psql) SaveTrack(ctx context.Context, opts db.SaveTrackOpts) (*models.Track, error) {
// create track in DB
l := logger.FromContext(ctx)
var insertMbzID *uuid.UUID
if opts.RecordingMbzID != uuid.Nil {
insertMbzID = &opts.RecordingMbzID
}
if len(opts.ArtistIDs) < 1 {
return nil, errors.New("required parameter 'ArtistIDs' missing")
}
for _, aid := range opts.ArtistIDs {
if aid == 0 {
return nil, errors.New("none of 'ArtistIDs' may be 0")
}
}
if opts.AlbumID == 0 {
return nil, errors.New("required parameter 'AlbumID' missing")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return nil, err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
l.Debug().Msgf("Inserting new track '%s' into DB", opts.Title)
trackRow, err := qtx.InsertTrack(ctx, repository.InsertTrackParams{
MusicBrainzID: insertMbzID,
ReleaseID: opts.AlbumID,
})
if err != nil {
return nil, err
}
// insert associated artists
for _, aid := range opts.ArtistIDs {
err = qtx.AssociateArtistToTrack(ctx, repository.AssociateArtistToTrackParams{
ArtistID: aid,
TrackID: trackRow.ID,
})
if err != nil {
return nil, err
}
}
// insert primary alias
err = qtx.InsertTrackAlias(ctx, repository.InsertTrackAliasParams{
TrackID: trackRow.ID,
Alias: opts.Title,
Source: "Canonical",
IsPrimary: true,
})
if err != nil {
return nil, err
}
err = tx.Commit(ctx)
if err != nil {
return nil, err
}
return &models.Track{
ID: trackRow.ID,
MbzID: insertMbzID,
Title: opts.Title,
}, nil
}
func (d *Psql) UpdateTrack(ctx context.Context, opts db.UpdateTrackOpts) error {
l := logger.FromContext(ctx)
if opts.ID == 0 {
return errors.New("track id not specified")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
if opts.MusicBrainzID != uuid.Nil {
l.Debug().Msgf("Updating MusicBrainz ID for track %d", opts.ID)
err := qtx.UpdateTrackMbzID(ctx, repository.UpdateTrackMbzIDParams{
ID: opts.ID,
MusicBrainzID: &opts.MusicBrainzID,
})
if err != nil {
return err
}
}
if opts.Duration != 0 {
l.Debug().Msgf("Updating duration for track %d", opts.ID)
err := qtx.UpdateTrackDuration(ctx, repository.UpdateTrackDurationParams{
ID: opts.ID,
Duration: opts.Duration,
})
if err != nil {
return err
}
}
return tx.Commit(ctx)
}
func (d *Psql) SaveTrackAliases(ctx context.Context, id int32, aliases []string, source string) error {
l := logger.FromContext(ctx)
if id == 0 {
return errors.New("track id not specified")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
existing, err := qtx.GetAllTrackAliases(ctx, id)
if err != nil {
return err
}
for _, v := range existing {
aliases = append(aliases, v.Alias)
}
utils.Unique(&aliases)
for _, alias := range aliases {
if strings.TrimSpace(alias) == "" {
return errors.New("aliases cannot be blank")
}
err = qtx.InsertTrackAlias(ctx, repository.InsertTrackAliasParams{
Alias: strings.TrimSpace(alias),
TrackID: id,
Source: source,
IsPrimary: false,
})
if err != nil {
return err
}
}
return tx.Commit(ctx)
}
func (d *Psql) DeleteTrack(ctx context.Context, id int32) error {
return d.q.DeleteTrack(ctx, id)
}
func (d *Psql) DeleteTrackAlias(ctx context.Context, id int32, alias string) error {
return d.q.DeleteTrackAlias(ctx, repository.DeleteTrackAliasParams{
TrackID: id,
Alias: alias,
})
}
func (d *Psql) GetAllTrackAliases(ctx context.Context, id int32) ([]models.Alias, error) {
rows, err := d.q.GetAllTrackAliases(ctx, id)
if err != nil {
return nil, err
}
aliases := make([]models.Alias, len(rows))
for i, row := range rows {
aliases[i] = models.Alias{
ID: id,
Alias: row.Alias,
Source: row.Source,
Primary: row.IsPrimary,
}
}
return aliases, nil
}
func (d *Psql) SetPrimaryTrackAlias(ctx context.Context, id int32, alias string) error {
l := logger.FromContext(ctx)
if id == 0 {
return errors.New("artist id not specified")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
// get all aliases
aliases, err := qtx.GetAllTrackAliases(ctx, id)
if err != nil {
return err
}
primary := ""
exists := false
for _, v := range aliases {
if v.Alias == alias {
exists = true
}
if v.IsPrimary {
primary = v.Alias
}
}
if primary == alias {
// no-op rename
return nil
}
if !exists {
return errors.New("alias does not exist")
}
err = qtx.SetTrackAliasPrimaryStatus(ctx, repository.SetTrackAliasPrimaryStatusParams{
TrackID: id,
Alias: alias,
IsPrimary: true,
})
if err != nil {
return err
}
err = qtx.SetTrackAliasPrimaryStatus(ctx, repository.SetTrackAliasPrimaryStatusParams{
TrackID: id,
Alias: primary,
IsPrimary: false,
})
if err != nil {
return err
}
return tx.Commit(ctx)
}

View file

@ -0,0 +1,213 @@
package psql_test
import (
"context"
"testing"
"github.com/gabehf/koito/internal/db"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func testDataForTracks(t *testing.T) {
truncateTestData(t)
// Insert artists
err := store.Exec(context.Background(),
`INSERT INTO artists (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000001'),
('00000000-0000-0000-0000-000000000002')`)
require.NoError(t, err)
// Insert artist aliases
err = store.Exec(context.Background(),
`INSERT INTO artist_aliases (artist_id, alias, source, is_primary)
VALUES (1, 'Artist One', 'Testing', true),
(2, 'Artist Two', 'Testing', true)`)
require.NoError(t, err)
// Insert release groups
err = store.Exec(context.Background(),
`INSERT INTO releases (musicbrainz_id)
VALUES ('00000000-0000-0000-0000-000000000011'),
('00000000-0000-0000-0000-000000000022')`)
require.NoError(t, err)
// Insert release aliases
err = store.Exec(context.Background(),
`INSERT INTO release_aliases (release_id, alias, source, is_primary)
VALUES (1, 'Release Group One', 'Testing', true),
(2, 'Release Group Two', 'Testing', true)`)
require.NoError(t, err)
// Insert tracks
err = store.Exec(context.Background(),
`INSERT INTO tracks (musicbrainz_id, release_id)
VALUES ('11111111-1111-1111-1111-111111111111', 1),
('22222222-2222-2222-2222-222222222222', 2)`)
require.NoError(t, err)
// Insert track aliases
err = store.Exec(context.Background(),
`INSERT INTO track_aliases (track_id, alias, source, is_primary)
VALUES (1, 'Track One', 'Testing', true),
(2, 'Track Two', 'Testing', true)`)
require.NoError(t, err)
// Associate tracks with artists
err = store.Exec(context.Background(),
`INSERT INTO artist_tracks (artist_id, track_id)
VALUES (1, 1), (2, 2)`)
require.NoError(t, err)
}
func TestGetTrack(t *testing.T) {
testDataForTracks(t)
ctx := context.Background()
// Test GetTrack by ID
track, err := store.GetTrack(ctx, db.GetTrackOpts{ID: 1})
require.NoError(t, err)
assert.Equal(t, int32(1), track.ID)
assert.Equal(t, "Track One", track.Title)
assert.Equal(t, uuid.MustParse("11111111-1111-1111-1111-111111111111"), *track.MbzID)
// Test GetTrack by MusicBrainzID
track, err = store.GetTrack(ctx, db.GetTrackOpts{MusicBrainzID: uuid.MustParse("22222222-2222-2222-2222-222222222222")})
require.NoError(t, err)
assert.Equal(t, int32(2), track.ID)
assert.Equal(t, "Track Two", track.Title)
// Test GetTrack by Title and ArtistIDs
track, err = store.GetTrack(ctx, db.GetTrackOpts{
Title: "Track One",
ArtistIDs: []int32{1},
})
require.NoError(t, err)
assert.Equal(t, int32(1), track.ID)
assert.Equal(t, "Track One", track.Title)
// Test GetTrack with insufficient information
_, err = store.GetTrack(ctx, db.GetTrackOpts{})
assert.Error(t, err)
}
func TestSaveTrack(t *testing.T) {
testDataForTracks(t)
ctx := context.Background()
// Test SaveTrack with valid inputs
track, err := store.SaveTrack(ctx, db.SaveTrackOpts{
Title: "New Track",
ArtistIDs: []int32{1},
RecordingMbzID: uuid.MustParse("33333333-3333-3333-3333-333333333333"),
AlbumID: 1,
})
require.NoError(t, err)
assert.Equal(t, "New Track", track.Title)
assert.Equal(t, uuid.MustParse("33333333-3333-3333-3333-333333333333"), *track.MbzID)
// Verify artist associations exist
exists, err := store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM artist_tracks
WHERE artist_id = $1 AND track_id = $2
)`, 1, track.ID)
require.NoError(t, err)
assert.True(t, exists, "expected artist association to exist")
// Verify alias exists
exists, err = store.RowExists(ctx, `
SELECT EXISTS (
SELECT 1 FROM track_aliases
WHERE track_id = $1 AND is_primary = true
)`, track.ID)
require.NoError(t, err)
assert.True(t, exists, "expected primary alias to exist")
// Test SaveTrack with missing ArtistIDs
_, err = store.SaveTrack(ctx, db.SaveTrackOpts{
Title: "Invalid Track",
ArtistIDs: []int32{},
RecordingMbzID: uuid.MustParse("44444444-4444-4444-4444-444444444444"),
})
assert.Error(t, err)
// Test SaveTrack with invalid ArtistIDs
_, err = store.SaveTrack(ctx, db.SaveTrackOpts{
Title: "Invalid Track",
ArtistIDs: []int32{0},
RecordingMbzID: uuid.MustParse("55555555-5555-5555-5555-555555555555"),
})
assert.Error(t, err)
}
func TestUpdateTrack(t *testing.T) {
testDataForTracks(t)
ctx := context.Background()
newMbzID := uuid.MustParse("66666666-6666-6666-6666-666666666666")
newDuration := 100
err := store.UpdateTrack(ctx, db.UpdateTrackOpts{
ID: 1,
MusicBrainzID: newMbzID,
Duration: int32(newDuration),
})
require.NoError(t, err)
// Verify the update
track, err := store.GetTrack(ctx, db.GetTrackOpts{ID: 1})
require.NoError(t, err)
require.Equal(t, newMbzID, *track.MbzID)
require.EqualValues(t, newDuration, track.Duration)
// Test UpdateTrack with missing ID
err = store.UpdateTrack(ctx, db.UpdateTrackOpts{
ID: 0,
MusicBrainzID: newMbzID,
Duration: int32(newDuration),
})
assert.Error(t, err)
// Test UpdateTrack with nil MusicBrainz ID
err = store.UpdateTrack(ctx, db.UpdateTrackOpts{
ID: 1,
MusicBrainzID: uuid.Nil,
Duration: int32(newDuration),
})
assert.NoError(t, err) // No update should occur
}
func TestTrackAliases(t *testing.T) {
testDataForTracks(t)
ctx := context.Background()
err := store.SaveTrackAliases(ctx, 1, []string{"Alias One", "Alias Two"}, "Testing")
require.NoError(t, err)
aliases, err := store.GetAllTrackAliases(ctx, 1)
require.NoError(t, err)
assert.Len(t, aliases, 3)
err = store.SetPrimaryTrackAlias(ctx, 1, "Alias One")
require.NoError(t, err)
track, err := store.GetTrack(ctx, db.GetTrackOpts{ID: 1})
require.NoError(t, err)
assert.Equal(t, "Alias One", track.Title)
err = store.SetPrimaryTrackAlias(ctx, 1, "Fake Alias")
require.Error(t, err)
store.SetPrimaryTrackAlias(ctx, 1, "Track One")
}
func TestDeleteTrack(t *testing.T) {
testDataForTracks(t)
ctx := context.Background()
err := store.DeleteTrack(ctx, 2)
require.NoError(t, err)
_, err = store.Count(ctx, `SELECT * FROM tracks WHERE id = 2`)
require.ErrorIs(t, err, pgx.ErrNoRows) // no rows error
}

219
internal/db/psql/user.go Normal file
View file

@ -0,0 +1,219 @@
package psql
import (
"context"
"errors"
"regexp"
"strings"
"unicode/utf8"
"github.com/gabehf/koito/internal/db"
"github.com/gabehf/koito/internal/logger"
"github.com/gabehf/koito/internal/models"
"github.com/gabehf/koito/internal/repository"
"github.com/jackc/pgx/v5"
"golang.org/x/crypto/bcrypt"
)
// Returns nil, nil when no database entries are found
func (d *Psql) GetUserByUsername(ctx context.Context, username string) (*models.User, error) {
row, err := d.q.GetUserByUsername(ctx, strings.ToLower(username))
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
}
return &models.User{
ID: row.ID,
Username: row.Username,
Password: row.Password,
Role: models.UserRole(row.Role),
}, nil
}
// Returns nil, nil when no database entries are found
func (d *Psql) GetUserByApiKey(ctx context.Context, key string) (*models.User, error) {
row, err := d.q.GetUserByApiKey(ctx, key)
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
}
return &models.User{
ID: row.ID,
Username: row.Username,
Password: row.Password,
Role: models.UserRole(row.Role),
}, nil
}
func (d *Psql) SaveUser(ctx context.Context, opts db.SaveUserOpts) (*models.User, error) {
l := logger.FromContext(ctx)
err := ValidateUsername(opts.Username)
if err != nil {
l.Debug().AnErr("validator_notice", err).Msgf("Username failed validation: %s", opts.Username)
return nil, err
}
pw, err := ValidateAndNormalizePassword(opts.Password)
if err != nil {
l.Debug().AnErr("validator_notice", err).Msgf("Password failed validation")
return nil, err
}
if opts.Role == "" {
opts.Role = models.UserRoleUser
}
hashPw, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost)
if err != nil {
l.Err(err).Msg("Failed to generate hashed password")
return nil, err
}
u, err := d.q.InsertUser(ctx, repository.InsertUserParams{
Username: strings.ToLower(opts.Username),
Password: hashPw,
Role: repository.Role(opts.Role),
})
if err != nil {
return nil, err
}
return &models.User{
ID: u.ID,
Username: u.Username,
Role: models.UserRole(u.Role),
}, nil
}
func (d *Psql) SaveApiKey(ctx context.Context, opts db.SaveApiKeyOpts) (*models.ApiKey, error) {
row, err := d.q.InsertApiKey(ctx, repository.InsertApiKeyParams{
Key: opts.Key,
Label: opts.Label,
UserID: opts.UserID,
})
if err != nil {
return nil, err
}
return &models.ApiKey{
ID: row.ID,
UserID: row.UserID,
Key: row.Key,
Label: row.Label,
CreatedAt: row.CreatedAt.Time,
}, nil
}
func (d *Psql) UpdateUser(ctx context.Context, opts db.UpdateUserOpts) error {
l := logger.FromContext(ctx)
if opts.ID == 0 {
return errors.New("user id is required")
}
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
l.Err(err).Msg("Failed to begin transaction")
return err
}
defer tx.Rollback(ctx)
qtx := d.q.WithTx(tx)
if opts.Username != "" {
err := ValidateUsername(opts.Username)
if err != nil {
l.Debug().AnErr("validator_notice", err).Msgf("Username failed validation: %s", opts.Username)
return err
}
err = qtx.UpdateUserUsername(ctx, repository.UpdateUserUsernameParams{
ID: opts.ID,
Username: opts.Username,
})
if err != nil {
return err
}
}
if opts.Password != "" {
pw, err := ValidateAndNormalizePassword(opts.Password)
if err != nil {
l.Debug().AnErr("validator_notice", err).Msgf("Password failed validation")
return err
}
hashPw, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost)
if err != nil {
l.Err(err).Msg("Failed to generate hashed password")
return err
}
err = qtx.UpdateUserPassword(ctx, repository.UpdateUserPasswordParams{
ID: opts.ID,
Password: hashPw,
})
if err != nil {
return err
}
}
return tx.Commit(ctx)
}
func (d *Psql) GetApiKeysByUserID(ctx context.Context, id int32) ([]models.ApiKey, error) {
rows, err := d.q.GetAllApiKeysByUserID(ctx, id)
if err != nil {
return nil, err
}
keys := make([]models.ApiKey, len(rows))
for i, row := range rows {
keys[i] = models.ApiKey{
ID: row.ID,
Key: row.Key,
Label: row.Label,
UserID: row.UserID,
}
}
return keys, nil
}
func (d *Psql) UpdateApiKeyLabel(ctx context.Context, opts db.UpdateApiKeyLabelOpts) error {
return d.q.UpdateApiKeyLabel(ctx, repository.UpdateApiKeyLabelParams{
ID: opts.ID,
Label: opts.Label,
UserID: opts.UserID,
})
}
func (d *Psql) DeleteApiKey(ctx context.Context, id int32) error {
return d.q.DeleteApiKey(ctx, id)
}
func (d *Psql) CountUsers(ctx context.Context) (int64, error) {
return d.q.CountUsers(ctx)
}
const (
maxUsernameLength = 32
minUsernameLength = 1
maxPasswordLength = 128
minPasswordLength = 8
)
var usernameRegex = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`)
func ValidateUsername(username string) error {
length := utf8.RuneCountInString(username)
if length < minUsernameLength || length > maxUsernameLength {
return errors.New("username must be between 1 and 32 characters")
}
if !usernameRegex.MatchString(username) {
return errors.New("username can only contain [a-zA-Z0-9_.-]")
}
return nil
}
func ValidateAndNormalizePassword(password string) (string, error) {
length := utf8.RuneCountInString(password)
if length < minPasswordLength {
return "", errors.New("password must be at least 8 characters long")
}
if length > maxPasswordLength {
var truncated []rune
for i, r := range password {
if i >= maxPasswordLength {
break
}
truncated = append(truncated, r)
}
password = string(truncated)
}
return password, nil
}

View file

@ -0,0 +1,199 @@
package psql_test
import (
"context"
"testing"
"github.com/gabehf/koito/internal/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
)
func setupTestDataForUsers(t *testing.T) {
truncateTestDataForUsers(t)
// Insert additional test users
err := store.Exec(context.Background(),
`INSERT INTO users (username, password, role)
VALUES ('test_user', $1, 'user'),
('admin_user', $1, 'admin')`, []byte("hashed_password"))
require.NoError(t, err)
}
func truncateTestDataForUsers(t *testing.T) {
err := store.Exec(context.Background(),
`DELETE FROM users WHERE id NOT IN (1)`,
)
require.NoError(t, err)
err = store.Exec(context.Background(),
`ALTER SEQUENCE users_id_seq RESTART WITH 2`,
)
require.NoError(t, err)
err = store.Exec(context.Background(),
`TRUNCATE api_keys RESTART IDENTITY CASCADE`,
)
require.NoError(t, err)
}
func TestGetUserByUsername(t *testing.T) {
ctx := context.Background()
setupTestDataForUsers(t)
// Test fetching an existing user
user, err := store.GetUserByUsername(ctx, "test_user")
require.NoError(t, err)
require.NotNil(t, user)
assert.Equal(t, "test_user", user.Username)
assert.Equal(t, "user", string(user.Role))
// Test fetching a non-existent user
user, err = store.GetUserByUsername(ctx, "nonexistent_user")
require.NoError(t, err)
assert.Nil(t, user)
}
func TestGetUserByApiKey(t *testing.T) {
ctx := context.Background()
setupTestDataForUsers(t)
// Insert an API key for the test user
err := store.Exec(ctx, `INSERT INTO api_keys (key, label, user_id) VALUES ('test_key', 'Test Key', 2)`)
require.NoError(t, err)
// Test fetching a user by API key
user, err := store.GetUserByApiKey(ctx, "test_key")
require.NoError(t, err)
require.NotNil(t, user)
assert.Equal(t, int32(2), user.ID)
assert.Equal(t, "test_user", user.Username)
// Test fetching a user with a non-existent API key
user, err = store.GetUserByApiKey(ctx, "nonexistent_key")
require.NoError(t, err)
assert.Nil(t, user)
}
func TestSaveUser(t *testing.T) {
ctx := context.Background()
setupTestDataForUsers(t)
// Save a new user
opts := db.SaveUserOpts{
Username: "new_user",
Password: "secure_password",
Role: "user",
}
user, err := store.SaveUser(ctx, opts)
require.NoError(t, err)
require.NotNil(t, user)
assert.Equal(t, "new_user", user.Username)
assert.Equal(t, "user", string(user.Role))
// Verify the password was hashed
var hashedPassword []byte
err = store.QueryRow(ctx, `SELECT password FROM users WHERE username = $1`, "new_user").Scan(&hashedPassword)
require.NoError(t, err)
assert.NoError(t, bcrypt.CompareHashAndPassword(hashedPassword, []byte(opts.Password)))
// Test validation failures
_, err = store.SaveUser(ctx, db.SaveUserOpts{
Username: "Q!@JH(F_H@#!*HF#*)&@",
Password: "testpassword12345",
})
assert.Error(t, err)
_, err = store.SaveUser(ctx, db.SaveUserOpts{
Username: "test_user",
Password: "<3",
})
assert.Error(t, err)
}
func TestSaveApiKey(t *testing.T) {
ctx := context.Background()
setupTestDataForUsers(t)
// Save an API key for the test user
label := "New API Key"
opts := db.SaveApiKeyOpts{
Key: "new_api_key",
Label: label,
UserID: 2,
}
_, err := store.SaveApiKey(ctx, opts)
require.NoError(t, err)
// Verify the API key was saved
count, err := store.Count(ctx, `SELECT COUNT(*) FROM api_keys WHERE key = $1 AND user_id = $2`, opts.Key, opts.UserID)
require.NoError(t, err)
assert.Equal(t, 1, count)
}
func TestGetApiKeysByUserID(t *testing.T) {
ctx := context.Background()
setupTestDataForUsers(t)
// Insert API keys for the test user
err := store.Exec(ctx, `INSERT INTO api_keys (key, label, user_id) VALUES
('key1', 'Key 1', 2),
('key2', 'Key 2', 2)`)
require.NoError(t, err)
// Fetch API keys for the test user
keys, err := store.GetApiKeysByUserID(ctx, 2)
require.NoError(t, err)
require.Len(t, keys, 2)
assert.Equal(t, "key1", keys[0].Key)
assert.Equal(t, "key2", keys[1].Key)
}
func TestUpdateApiKeyLabel(t *testing.T) {
ctx := context.Background()
setupTestDataForUsers(t)
// Insert an API key for the test user
err := store.Exec(ctx, `INSERT INTO api_keys (key, label, user_id) VALUES ('key_to_update', 'Old Label', 2)`)
require.NoError(t, err)
// Update the API key label
opts := db.UpdateApiKeyLabelOpts{
ID: 1,
Label: "Updated Label",
UserID: 2,
}
err = store.UpdateApiKeyLabel(ctx, opts)
require.NoError(t, err)
// Verify the label was updated
var label string
err = store.QueryRow(ctx, `SELECT label FROM api_keys WHERE id = $1`, opts.ID).Scan(&label)
require.NoError(t, err)
assert.Equal(t, "Updated Label", label)
}
func TestDeleteApiKey(t *testing.T) {
ctx := context.Background()
setupTestDataForUsers(t)
// Insert an API key for the test user
err := store.Exec(ctx, `INSERT INTO api_keys (key, label, user_id) VALUES ('key_to_delete', 'Label', 2)`)
require.NoError(t, err)
// Delete the API key
err = store.DeleteApiKey(ctx, 1) // Assuming the ID is auto-generated and starts from 1
require.NoError(t, err)
// Verify the API key was deleted
count, err := store.Count(ctx, `SELECT COUNT(*) FROM api_keys WHERE id = $1`, 1)
require.NoError(t, err)
assert.Equal(t, 0, count)
}
func TestCountUsers(t *testing.T) {
ctx := context.Background()
setupTestDataForUsers(t)
// Count the number of users
count, err := store.Count(ctx, `SELECT COUNT(*) FROM users`)
require.NoError(t, err)
assert.GreaterOrEqual(t, count, 3) // Special user + test users
}