mirror of
https://github.com/gabehf/Koito.git
synced 2026-03-16 19:05:54 -07:00
chore: initial public commit
This commit is contained in:
commit
fc9054b78c
250 changed files with 32809 additions and 0 deletions
219
internal/db/psql/user.go
Normal file
219
internal/db/psql/user.go
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gabehf/koito/internal/db"
|
||||
"github.com/gabehf/koito/internal/logger"
|
||||
"github.com/gabehf/koito/internal/models"
|
||||
"github.com/gabehf/koito/internal/repository"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Returns nil, nil when no database entries are found
|
||||
func (d *Psql) GetUserByUsername(ctx context.Context, username string) (*models.User, error) {
|
||||
row, err := d.q.GetUserByUsername(ctx, strings.ToLower(username))
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &models.User{
|
||||
ID: row.ID,
|
||||
Username: row.Username,
|
||||
Password: row.Password,
|
||||
Role: models.UserRole(row.Role),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Returns nil, nil when no database entries are found
|
||||
func (d *Psql) GetUserByApiKey(ctx context.Context, key string) (*models.User, error) {
|
||||
row, err := d.q.GetUserByApiKey(ctx, key)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &models.User{
|
||||
ID: row.ID,
|
||||
Username: row.Username,
|
||||
Password: row.Password,
|
||||
Role: models.UserRole(row.Role),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Psql) SaveUser(ctx context.Context, opts db.SaveUserOpts) (*models.User, error) {
|
||||
l := logger.FromContext(ctx)
|
||||
err := ValidateUsername(opts.Username)
|
||||
if err != nil {
|
||||
l.Debug().AnErr("validator_notice", err).Msgf("Username failed validation: %s", opts.Username)
|
||||
return nil, err
|
||||
}
|
||||
pw, err := ValidateAndNormalizePassword(opts.Password)
|
||||
if err != nil {
|
||||
l.Debug().AnErr("validator_notice", err).Msgf("Password failed validation")
|
||||
return nil, err
|
||||
}
|
||||
if opts.Role == "" {
|
||||
opts.Role = models.UserRoleUser
|
||||
}
|
||||
hashPw, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to generate hashed password")
|
||||
return nil, err
|
||||
}
|
||||
u, err := d.q.InsertUser(ctx, repository.InsertUserParams{
|
||||
Username: strings.ToLower(opts.Username),
|
||||
Password: hashPw,
|
||||
Role: repository.Role(opts.Role),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &models.User{
|
||||
ID: u.ID,
|
||||
Username: u.Username,
|
||||
Role: models.UserRole(u.Role),
|
||||
}, nil
|
||||
}
|
||||
func (d *Psql) SaveApiKey(ctx context.Context, opts db.SaveApiKeyOpts) (*models.ApiKey, error) {
|
||||
row, err := d.q.InsertApiKey(ctx, repository.InsertApiKeyParams{
|
||||
Key: opts.Key,
|
||||
Label: opts.Label,
|
||||
UserID: opts.UserID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &models.ApiKey{
|
||||
ID: row.ID,
|
||||
UserID: row.UserID,
|
||||
Key: row.Key,
|
||||
Label: row.Label,
|
||||
CreatedAt: row.CreatedAt.Time,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Psql) UpdateUser(ctx context.Context, opts db.UpdateUserOpts) error {
|
||||
l := logger.FromContext(ctx)
|
||||
if opts.ID == 0 {
|
||||
return errors.New("user id is required")
|
||||
}
|
||||
tx, err := d.conn.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to begin transaction")
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
qtx := d.q.WithTx(tx)
|
||||
if opts.Username != "" {
|
||||
err := ValidateUsername(opts.Username)
|
||||
if err != nil {
|
||||
l.Debug().AnErr("validator_notice", err).Msgf("Username failed validation: %s", opts.Username)
|
||||
return err
|
||||
}
|
||||
err = qtx.UpdateUserUsername(ctx, repository.UpdateUserUsernameParams{
|
||||
ID: opts.ID,
|
||||
Username: opts.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if opts.Password != "" {
|
||||
pw, err := ValidateAndNormalizePassword(opts.Password)
|
||||
if err != nil {
|
||||
l.Debug().AnErr("validator_notice", err).Msgf("Password failed validation")
|
||||
return err
|
||||
}
|
||||
hashPw, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
l.Err(err).Msg("Failed to generate hashed password")
|
||||
return err
|
||||
}
|
||||
err = qtx.UpdateUserPassword(ctx, repository.UpdateUserPasswordParams{
|
||||
ID: opts.ID,
|
||||
Password: hashPw,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (d *Psql) GetApiKeysByUserID(ctx context.Context, id int32) ([]models.ApiKey, error) {
|
||||
rows, err := d.q.GetAllApiKeysByUserID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys := make([]models.ApiKey, len(rows))
|
||||
for i, row := range rows {
|
||||
keys[i] = models.ApiKey{
|
||||
ID: row.ID,
|
||||
Key: row.Key,
|
||||
Label: row.Label,
|
||||
UserID: row.UserID,
|
||||
}
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (d *Psql) UpdateApiKeyLabel(ctx context.Context, opts db.UpdateApiKeyLabelOpts) error {
|
||||
return d.q.UpdateApiKeyLabel(ctx, repository.UpdateApiKeyLabelParams{
|
||||
ID: opts.ID,
|
||||
Label: opts.Label,
|
||||
UserID: opts.UserID,
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Psql) DeleteApiKey(ctx context.Context, id int32) error {
|
||||
return d.q.DeleteApiKey(ctx, id)
|
||||
}
|
||||
|
||||
func (d *Psql) CountUsers(ctx context.Context) (int64, error) {
|
||||
return d.q.CountUsers(ctx)
|
||||
}
|
||||
|
||||
const (
|
||||
maxUsernameLength = 32
|
||||
minUsernameLength = 1
|
||||
maxPasswordLength = 128
|
||||
minPasswordLength = 8
|
||||
)
|
||||
|
||||
var usernameRegex = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`)
|
||||
|
||||
func ValidateUsername(username string) error {
|
||||
length := utf8.RuneCountInString(username)
|
||||
if length < minUsernameLength || length > maxUsernameLength {
|
||||
return errors.New("username must be between 1 and 32 characters")
|
||||
}
|
||||
if !usernameRegex.MatchString(username) {
|
||||
return errors.New("username can only contain [a-zA-Z0-9_.-]")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateAndNormalizePassword(password string) (string, error) {
|
||||
length := utf8.RuneCountInString(password)
|
||||
if length < minPasswordLength {
|
||||
return "", errors.New("password must be at least 8 characters long")
|
||||
}
|
||||
if length > maxPasswordLength {
|
||||
var truncated []rune
|
||||
for i, r := range password {
|
||||
if i >= maxPasswordLength {
|
||||
break
|
||||
}
|
||||
truncated = append(truncated, r)
|
||||
}
|
||||
password = string(truncated)
|
||||
}
|
||||
return password, nil
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue