mirror of https://github.com/gabehf/Koito.git
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
220 lines
5.6 KiB
220 lines
5.6 KiB
package psql
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"regexp"
|
|
"strings"
|
|
"unicode/utf8"
|
|
|
|
"github.com/gabehf/koito/internal/db"
|
|
"github.com/gabehf/koito/internal/logger"
|
|
"github.com/gabehf/koito/internal/models"
|
|
"github.com/gabehf/koito/internal/repository"
|
|
"github.com/jackc/pgx/v5"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// Returns nil, nil when no database entries are found
|
|
func (d *Psql) GetUserByUsername(ctx context.Context, username string) (*models.User, error) {
|
|
row, err := d.q.GetUserByUsername(ctx, strings.ToLower(username))
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, nil
|
|
} else if err != nil {
|
|
return nil, err
|
|
}
|
|
return &models.User{
|
|
ID: row.ID,
|
|
Username: row.Username,
|
|
Password: row.Password,
|
|
Role: models.UserRole(row.Role),
|
|
}, nil
|
|
}
|
|
|
|
// Returns nil, nil when no database entries are found
|
|
func (d *Psql) GetUserByApiKey(ctx context.Context, key string) (*models.User, error) {
|
|
row, err := d.q.GetUserByApiKey(ctx, key)
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, nil
|
|
} else if err != nil {
|
|
return nil, err
|
|
}
|
|
return &models.User{
|
|
ID: row.ID,
|
|
Username: row.Username,
|
|
Password: row.Password,
|
|
Role: models.UserRole(row.Role),
|
|
}, nil
|
|
}
|
|
|
|
func (d *Psql) SaveUser(ctx context.Context, opts db.SaveUserOpts) (*models.User, error) {
|
|
l := logger.FromContext(ctx)
|
|
err := ValidateUsername(opts.Username)
|
|
if err != nil {
|
|
l.Debug().AnErr("validator_notice", err).Msgf("Username failed validation: %s", opts.Username)
|
|
return nil, err
|
|
}
|
|
pw, err := ValidateAndNormalizePassword(opts.Password)
|
|
if err != nil {
|
|
l.Debug().AnErr("validator_notice", err).Msgf("Password failed validation")
|
|
return nil, err
|
|
}
|
|
if opts.Role == "" {
|
|
opts.Role = models.UserRoleUser
|
|
}
|
|
hashPw, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
l.Err(err).Msg("Failed to generate hashed password")
|
|
return nil, err
|
|
}
|
|
u, err := d.q.InsertUser(ctx, repository.InsertUserParams{
|
|
Username: strings.ToLower(opts.Username),
|
|
Password: hashPw,
|
|
Role: repository.Role(opts.Role),
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &models.User{
|
|
ID: u.ID,
|
|
Username: u.Username,
|
|
Role: models.UserRole(u.Role),
|
|
}, nil
|
|
}
|
|
func (d *Psql) SaveApiKey(ctx context.Context, opts db.SaveApiKeyOpts) (*models.ApiKey, error) {
|
|
row, err := d.q.InsertApiKey(ctx, repository.InsertApiKeyParams{
|
|
Key: opts.Key,
|
|
Label: opts.Label,
|
|
UserID: opts.UserID,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &models.ApiKey{
|
|
ID: row.ID,
|
|
UserID: row.UserID,
|
|
Key: row.Key,
|
|
Label: row.Label,
|
|
CreatedAt: row.CreatedAt.Time,
|
|
}, nil
|
|
}
|
|
|
|
func (d *Psql) UpdateUser(ctx context.Context, opts db.UpdateUserOpts) error {
|
|
l := logger.FromContext(ctx)
|
|
if opts.ID == 0 {
|
|
return errors.New("user id is required")
|
|
}
|
|
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
|
if err != nil {
|
|
l.Err(err).Msg("Failed to begin transaction")
|
|
return err
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
qtx := d.q.WithTx(tx)
|
|
if opts.Username != "" {
|
|
err := ValidateUsername(opts.Username)
|
|
if err != nil {
|
|
l.Debug().AnErr("validator_notice", err).Msgf("Username failed validation: %s", opts.Username)
|
|
return err
|
|
}
|
|
err = qtx.UpdateUserUsername(ctx, repository.UpdateUserUsernameParams{
|
|
ID: opts.ID,
|
|
Username: opts.Username,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if opts.Password != "" {
|
|
pw, err := ValidateAndNormalizePassword(opts.Password)
|
|
if err != nil {
|
|
l.Debug().AnErr("validator_notice", err).Msgf("Password failed validation")
|
|
return err
|
|
}
|
|
hashPw, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
l.Err(err).Msg("Failed to generate hashed password")
|
|
return err
|
|
}
|
|
err = qtx.UpdateUserPassword(ctx, repository.UpdateUserPasswordParams{
|
|
ID: opts.ID,
|
|
Password: hashPw,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return tx.Commit(ctx)
|
|
}
|
|
|
|
func (d *Psql) GetApiKeysByUserID(ctx context.Context, id int32) ([]models.ApiKey, error) {
|
|
rows, err := d.q.GetAllApiKeysByUserID(ctx, id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
keys := make([]models.ApiKey, len(rows))
|
|
for i, row := range rows {
|
|
keys[i] = models.ApiKey{
|
|
ID: row.ID,
|
|
Key: row.Key,
|
|
Label: row.Label,
|
|
UserID: row.UserID,
|
|
}
|
|
}
|
|
return keys, nil
|
|
}
|
|
|
|
func (d *Psql) UpdateApiKeyLabel(ctx context.Context, opts db.UpdateApiKeyLabelOpts) error {
|
|
return d.q.UpdateApiKeyLabel(ctx, repository.UpdateApiKeyLabelParams{
|
|
ID: opts.ID,
|
|
Label: opts.Label,
|
|
UserID: opts.UserID,
|
|
})
|
|
}
|
|
|
|
func (d *Psql) DeleteApiKey(ctx context.Context, id int32) error {
|
|
return d.q.DeleteApiKey(ctx, id)
|
|
}
|
|
|
|
func (d *Psql) CountUsers(ctx context.Context) (int64, error) {
|
|
return d.q.CountUsers(ctx)
|
|
}
|
|
|
|
// Length limits used by ValidateUsername and ValidateAndNormalizePassword.
// Both functions measure length in Unicode code points (runes), not bytes.
const (
	// minUsernameLength and maxUsernameLength bound a username's rune count.
	minUsernameLength = 1
	maxUsernameLength = 32

	// minPasswordLength is the shortest accepted password, in runes.
	minPasswordLength = 8
	// maxPasswordLength is the rune count above which
	// ValidateAndNormalizePassword truncates a password instead of rejecting it.
	maxPasswordLength = 128
)
|
|
|
|
var usernameRegex = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`)
|
|
|
|
func ValidateUsername(username string) error {
|
|
length := utf8.RuneCountInString(username)
|
|
if length < minUsernameLength || length > maxUsernameLength {
|
|
return errors.New("username must be between 1 and 32 characters")
|
|
}
|
|
if !usernameRegex.MatchString(username) {
|
|
return errors.New("username can only contain [a-zA-Z0-9_.-]")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ValidateAndNormalizePassword(password string) (string, error) {
|
|
length := utf8.RuneCountInString(password)
|
|
if length < minPasswordLength {
|
|
return "", errors.New("password must be at least 8 characters long")
|
|
}
|
|
if length > maxPasswordLength {
|
|
var truncated []rune
|
|
for i, r := range password {
|
|
if i >= maxPasswordLength {
|
|
break
|
|
}
|
|
truncated = append(truncated, r)
|
|
}
|
|
password = string(truncated)
|
|
}
|
|
return password, nil
|
|
}
|