Merge user and account (#701)

This commit is contained in:
Sung 2025-10-19 21:05:47 -07:00 committed by GitHub
commit 0a5728faf3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 248 additions and 477 deletions

View file

@ -136,8 +136,7 @@ func TestMain(m *testing.M) {
// helpers
func setupUser(t *testing.T, db *cliDatabase.DB) database.User {
user := apitest.SetupUserData(serverDb)
apitest.SetupAccountData(serverDb, user, "alice@example.com", "pass1234")
user := apitest.SetupUserData(serverDb, "alice@example.com", "pass1234")
return user
}
@ -4255,8 +4254,7 @@ func TestSync_EmptyServer(t *testing.T) {
// Step 1: Set up user on Server A and sync
apiEndpointA := fmt.Sprintf("%s/api", serverA.URL)
userA := apitest.SetupUserData(serverDbA)
apitest.SetupAccountData(serverDbA, userA, "alice@example.com", "pass1234")
userA := apitest.SetupUserData(serverDbA, "alice@example.com", "pass1234")
sessionA := apitest.SetupSession(serverDbA, userA)
cliDatabase.MustExec(t, "inserting session_key", ctx.DB, "INSERT INTO system (key, value) VALUES (?, ?)", consts.SystemSessionKey, sessionA.Key)
cliDatabase.MustExec(t, "inserting session_key_expiry", ctx.DB, "INSERT INTO system (key, value) VALUES (?, ?)", consts.SystemSessionKeyExpiry, sessionA.ExpiresAt.Unix())
@ -4280,8 +4278,7 @@ func TestSync_EmptyServer(t *testing.T) {
apiEndpointB := fmt.Sprintf("%s/api", serverB.URL)
// Set up user on Server B
userB := apitest.SetupUserData(serverDbB)
apitest.SetupAccountData(serverDbB, userB, "alice@example.com", "pass1234")
userB := apitest.SetupUserData(serverDbB, "alice@example.com", "pass1234")
sessionB := apitest.SetupSession(serverDbB, userB)
cliDatabase.MustExec(t, "updating session_key for B", ctx.DB, "UPDATE system SET value = ? WHERE key = ?", sessionB.Key, consts.SystemSessionKey)
cliDatabase.MustExec(t, "updating session_key_expiry for B", ctx.DB, "UPDATE system SET value = ? WHERE key = ?", sessionB.ExpiresAt.Unix(), consts.SystemSessionKeyExpiry)

View file

@ -56,10 +56,10 @@ func TestCreateBook(t *testing.T) {
func() {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
user := testutils.SetupUserData(db, "user@test.com", "password123")
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData(db)
anotherUser := testutils.SetupUserData(db, "another@test.com", "password123")
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
a := NewTest()
@ -122,10 +122,10 @@ func TestDeleteBook(t *testing.T) {
func() {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
user := testutils.SetupUserData(db, "user@test.com", "password123")
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData(db)
anotherUser := testutils.SetupUserData(db, "another@test.com", "password123")
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
book := database.Book{UserID: user.ID, Label: "js", Deleted: false}
@ -201,10 +201,10 @@ func TestUpdateBook(t *testing.T) {
func() {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
user := testutils.SetupUserData(db, "user@test.com", "password123")
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData(db)
anotherUser := testutils.SetupUserData(db, "another@test.com", "password123")
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
b := database.Book{UserID: user.ID, Deleted: false, Label: tc.expectedLabel}

View file

@ -48,7 +48,7 @@ func TestIncremenetUserUSN(t *testing.T) {
func() {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
user := testutils.SetupUserData(db, "user@test.com", "password123")
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.maxUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
// execute

View file

@ -77,11 +77,11 @@ func TestCreateNote(t *testing.T) {
func() {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
user := testutils.SetupUserData(db, "user@test.com", "password123")
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
fmt.Println(user)
anotherUser := testutils.SetupUserData(db)
anotherUser := testutils.SetupUserData(db, "another@test.com", "password123")
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
b1 := database.Book{UserID: user.ID, Label: "js", Deleted: false}
@ -130,7 +130,7 @@ func TestCreateNote(t *testing.T) {
func TestCreateNote_EmptyBody(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
user := testutils.SetupUserData(db, "user@test.com", "password123")
b1 := database.Book{UserID: user.ID, Label: "testBook"}
testutils.MustExec(t, db.Save(&b1), "preparing book")
@ -169,10 +169,10 @@ func TestUpdateNote(t *testing.T) {
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
user := testutils.SetupUserData(db, "user@test.com", "password123")
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), "preparing user max_usn for test case")
anotherUser := testutils.SetupUserData(db)
anotherUser := testutils.SetupUserData(db, "another@test.com", "password123")
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), "preparing user max_usn for test case")
b1 := database.Book{UserID: user.ID, Label: "js", Deleted: false}
@ -234,7 +234,7 @@ func TestUpdateNote(t *testing.T) {
func TestUpdateNote_SameContent(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
user := testutils.SetupUserData(db, "user@test.com", "password123")
b1 := database.Book{UserID: user.ID, Label: "testBook"}
testutils.MustExec(t, db.Save(&b1), "preparing book")
@ -291,10 +291,10 @@ func TestDeleteNote(t *testing.T) {
func() {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
user := testutils.SetupUserData(db, "user@test.com", "password123")
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData(db)
anotherUser := testutils.SetupUserData(db, "another@test.com", "password123")
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
b1 := database.Book{UserID: user.ID, Label: "testBook"}
@ -351,7 +351,7 @@ func TestDeleteNote(t *testing.T) {
func TestGetNotes_FTSSearch(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
user := testutils.SetupUserData(db, "user@test.com", "password123")
b1 := database.Book{UserID: user.ID, Label: "testBook"}
testutils.MustExec(t, db.Save(&b1), "preparing book")
@ -415,7 +415,7 @@ func TestGetNotes_FTSSearch(t *testing.T) {
func TestGetNotes_FTSSearch_Snippet(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
user := testutils.SetupUserData(db, "user@test.com", "password123")
b1 := database.Book{UserID: user.ID, Label: "testBook"}
testutils.MustExec(t, db.Save(&b1), "preparing book")
@ -449,7 +449,7 @@ func TestGetNotes_FTSSearch_Snippet(t *testing.T) {
func TestGetNotes_FTSSearch_ShortWord(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
user := testutils.SetupUserData(db, "user@test.com", "password123")
b1 := database.Book{UserID: user.ID, Label: "testBook"}
testutils.MustExec(t, db.Save(&b1), "preparing book")
@ -481,7 +481,7 @@ func TestGetNotes_FTSSearch_ShortWord(t *testing.T) {
func TestGetNotes_All(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
user := testutils.SetupUserData(db, "user@test.com", "password123")
b1 := database.Book{UserID: user.ID, Label: "testBook"}
testutils.MustExec(t, db.Save(&b1), "preparing book")

View file

@ -65,7 +65,7 @@ func (a *App) CreateUser(email, password string, passwordConfirmation string) (d
tx := a.DB.Begin()
var count int64
if err := tx.Model(database.Account{}).Where("email = ?", email).Count(&count).Error; err != nil {
if err := tx.Model(&database.User{}).Where("email = ?", email).Count(&count).Error; err != nil {
return database.User{}, pkgErrors.Wrap(err, "counting user")
}
if count > 0 {
@ -85,21 +85,14 @@ func (a *App) CreateUser(email, password string, passwordConfirmation string) (d
}
user := database.User{
UUID: uuid,
UUID: uuid,
Email: database.ToNullString(email),
Password: database.ToNullString(string(hashedPassword)),
}
if err = tx.Save(&user).Error; err != nil {
tx.Rollback()
return database.User{}, pkgErrors.Wrap(err, "saving user")
}
account := database.Account{
Email: database.ToNullString(email),
Password: database.ToNullString(string(hashedPassword)),
UserID: user.ID,
}
if err = tx.Save(&account).Error; err != nil {
tx.Rollback()
return database.User{}, pkgErrors.Wrap(err, "saving account")
}
if err := a.TouchLastLoginAt(user, tx); err != nil {
tx.Rollback()
@ -111,42 +104,36 @@ func (a *App) CreateUser(email, password string, passwordConfirmation string) (d
return user, nil
}
// GetAccountByEmail finds an account by email
func (a *App) GetAccountByEmail(email string) (*database.Account, error) {
var account database.Account
err := a.DB.Where("email = ?", email).First(&account).Error
// GetUserByEmail finds a user by email
func (a *App) GetUserByEmail(email string) (*database.User, error) {
var user database.User
err := a.DB.Where("email = ?", email).First(&user).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrNotFound
} else if err != nil {
return nil, err
}
return &account, nil
return &user, nil
}
// Authenticate authenticates a user
func (a *App) Authenticate(email, password string) (*database.User, error) {
account, err := a.GetAccountByEmail(email)
user, err := a.GetUserByEmail(email)
if err != nil {
return nil, err
}
err = bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte(password))
err = bcrypt.CompareHashAndPassword([]byte(user.Password.String), []byte(password))
if err != nil {
return nil, ErrLoginInvalid
}
var user database.User
err = a.DB.Where("id = ?", account.UserID).First(&user).Error
if err != nil {
return nil, pkgErrors.Wrap(err, "finding user")
}
return &user, nil
return user, nil
}
// UpdateAccountPassword updates an account's password with validation
func UpdateAccountPassword(db *gorm.DB, account *database.Account, newPassword string) error {
// UpdateUserPassword updates a user's password with validation
func UpdateUserPassword(db *gorm.DB, user *database.User, newPassword string) error {
// Validate password
if err := validatePassword(newPassword); err != nil {
return err
@ -159,25 +146,25 @@ func UpdateAccountPassword(db *gorm.DB, account *database.Account, newPassword s
}
// Update the password
if err := db.Model(&account).Update("password", string(hashedPassword)).Error; err != nil {
if err := db.Model(&user).Update("password", string(hashedPassword)).Error; err != nil {
return pkgErrors.Wrap(err, "updating password")
}
return nil
}
// RemoveUser removes a user and their account from the system
// RemoveUser removes a user from the system
// Returns an error if the user has any notes or books
func (a *App) RemoveUser(email string) error {
// Find the account and user
account, err := a.GetAccountByEmail(email)
// Find the user
user, err := a.GetUserByEmail(email)
if err != nil {
return err
}
// Check if user has any notes
var noteCount int64
if err := a.DB.Model(&database.Note{}).Where("user_id = ? AND deleted = ?", account.UserID, false).Count(&noteCount).Error; err != nil {
if err := a.DB.Model(&database.Note{}).Where("user_id = ? AND deleted = ?", user.ID, false).Count(&noteCount).Error; err != nil {
return pkgErrors.Wrap(err, "counting notes")
}
if noteCount > 0 {
@ -186,34 +173,18 @@ func (a *App) RemoveUser(email string) error {
// Check if user has any books
var bookCount int64
if err := a.DB.Model(&database.Book{}).Where("user_id = ? AND deleted = ?", account.UserID, false).Count(&bookCount).Error; err != nil {
if err := a.DB.Model(&database.Book{}).Where("user_id = ? AND deleted = ?", user.ID, false).Count(&bookCount).Error; err != nil {
return pkgErrors.Wrap(err, "counting books")
}
if bookCount > 0 {
return ErrUserHasExistingResources
}
// Delete account and user in a transaction
tx := a.DB.Begin()
if err := tx.Delete(&account).Error; err != nil {
tx.Rollback()
return pkgErrors.Wrap(err, "deleting account")
}
var user database.User
if err := tx.Where("id = ?", account.UserID).First(&user).Error; err != nil {
tx.Rollback()
return pkgErrors.Wrap(err, "finding user")
}
if err := tx.Delete(&user).Error; err != nil {
tx.Rollback()
// Delete user
if err := a.DB.Delete(&user).Error; err != nil {
return pkgErrors.Wrap(err, "deleting user")
}
tx.Commit()
return nil
}

View file

@ -82,21 +82,20 @@ func TestCreateUser_ProValue(t *testing.T) {
}
func TestGetAccountByEmail(t *testing.T) {
func TestGetUserByEmail(t *testing.T) {
t.Run("success", func(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@example.com", "password123")
user := testutils.SetupUserData(db, "alice@example.com", "password123")
a := NewTest()
a.DB = db
account, err := a.GetAccountByEmail("alice@example.com")
foundUser, err := a.GetUserByEmail("alice@example.com")
assert.Equal(t, err, nil, "should not error")
assert.Equal(t, account.Email.String, "alice@example.com", "email mismatch")
assert.Equal(t, account.UserID, user.ID, "user ID mismatch")
assert.Equal(t, foundUser.Email.String, "alice@example.com", "email mismatch")
assert.Equal(t, foundUser.ID, user.ID, "user ID mismatch")
})
t.Run("not found", func(t *testing.T) {
@ -105,10 +104,10 @@ func TestGetAccountByEmail(t *testing.T) {
a := NewTest()
a.DB = db
account, err := a.GetAccountByEmail("nonexistent@example.com")
user, err := a.GetUserByEmail("nonexistent@example.com")
assert.Equal(t, err, ErrNotFound, "should return ErrNotFound")
assert.Equal(t, account, (*database.Account)(nil), "account should be nil")
assert.Equal(t, user, (*database.User)(nil), "user should be nil")
})
}
@ -124,25 +123,21 @@ func TestCreateUser(t *testing.T) {
var userCount int64
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
assert.Equal(t, userCount, int64(1), "book count mismatch")
assert.Equal(t, userCount, int64(1), "user count mismatch")
var accountCount int64
var accountRecord database.Account
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
testutils.MustExec(t, db.First(&accountRecord), "finding account")
var userRecord database.User
testutils.MustExec(t, db.First(&userRecord), "finding user")
assert.Equal(t, accountCount, int64(1), "account count mismatch")
assert.Equal(t, accountRecord.Email.String, "alice@example.com", "account email mismatch")
assert.Equal(t, userRecord.Email.String, "alice@example.com", "user email mismatch")
passwordErr := bcrypt.CompareHashAndPassword([]byte(accountRecord.Password.String), []byte("pass1234"))
passwordErr := bcrypt.CompareHashAndPassword([]byte(userRecord.Password.String), []byte("pass1234"))
assert.Equal(t, passwordErr, nil, "Password mismatch")
})
t.Run("duplicate email", func(t *testing.T) {
db := testutils.InitMemoryDB(t)
aliceUser := testutils.SetupUserData(db)
testutils.SetupAccountData(db, aliceUser, "alice@example.com", "somepassword")
testutils.SetupUserData(db, "alice@example.com", "somepassword")
a := NewTest()
a.DB = db
@ -150,116 +145,109 @@ func TestCreateUser(t *testing.T) {
assert.Equal(t, err, ErrDuplicateEmail, "error mismatch")
var userCount, accountCount int64
var userCount int64
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
assert.Equal(t, userCount, int64(1), "user count mismatch")
assert.Equal(t, accountCount, int64(1), "account count mismatch")
})
}
func TestUpdateAccountPassword(t *testing.T) {
func TestUpdateUserPassword(t *testing.T) {
t.Run("success", func(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123")
user := testutils.SetupUserData(db, "alice@example.com", "oldpassword123")
err := UpdateAccountPassword(db, &account, "newpassword123")
err := UpdateUserPassword(db, &user, "newpassword123")
assert.Equal(t, err, nil, "should not error")
// Verify password was updated in database
var updatedAccount database.Account
testutils.MustExec(t, db.Where("id = ?", account.ID).First(&updatedAccount), "finding updated account")
var updatedUser database.User
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&updatedUser), "finding updated user")
// Verify new password works
passwordErr := bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("newpassword123"))
passwordErr := bcrypt.CompareHashAndPassword([]byte(updatedUser.Password.String), []byte("newpassword123"))
assert.Equal(t, passwordErr, nil, "New password should match")
// Verify old password no longer works
oldPasswordErr := bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("oldpassword123"))
oldPasswordErr := bcrypt.CompareHashAndPassword([]byte(updatedUser.Password.String), []byte("oldpassword123"))
assert.NotEqual(t, oldPasswordErr, nil, "Old password should not match")
})
t.Run("password too short", func(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123")
user := testutils.SetupUserData(db, "alice@example.com", "oldpassword123")
err := UpdateAccountPassword(db, &account, "short")
err := UpdateUserPassword(db, &user, "short")
assert.Equal(t, err, ErrPasswordTooShort, "should return ErrPasswordTooShort")
// Verify password was NOT updated in database
var unchangedAccount database.Account
testutils.MustExec(t, db.Where("id = ?", account.ID).First(&unchangedAccount), "finding unchanged account")
var unchangedUser database.User
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&unchangedUser), "finding unchanged user")
// Verify old password still works
passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedAccount.Password.String), []byte("oldpassword123"))
passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedUser.Password.String), []byte("oldpassword123"))
assert.Equal(t, passwordErr, nil, "Old password should still match")
})
t.Run("empty password", func(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123")
user := testutils.SetupUserData(db, "alice@example.com", "oldpassword123")
err := UpdateAccountPassword(db, &account, "")
err := UpdateUserPassword(db, &user, "")
assert.Equal(t, err, ErrPasswordTooShort, "should return ErrPasswordTooShort")
// Verify password was NOT updated in database
var unchangedAccount database.Account
testutils.MustExec(t, db.Where("id = ?", account.ID).First(&unchangedAccount), "finding unchanged account")
var unchangedUser database.User
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&unchangedUser), "finding unchanged user")
// Verify old password still works
passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedAccount.Password.String), []byte("oldpassword123"))
passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedUser.Password.String), []byte("oldpassword123"))
assert.Equal(t, passwordErr, nil, "Old password should still match")
})
t.Run("transaction rollback", func(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123")
user := testutils.SetupUserData(db, "alice@example.com", "oldpassword123")
// Start a transaction and rollback to verify UpdateAccountPassword respects transactions
// Start a transaction and rollback to verify UpdateUserPassword respects transactions
tx := db.Begin()
err := UpdateAccountPassword(tx, &account, "newpassword123")
err := UpdateUserPassword(tx, &user, "newpassword123")
assert.Equal(t, err, nil, "should not error")
tx.Rollback()
// Verify password was NOT updated after rollback
var unchangedAccount database.Account
testutils.MustExec(t, db.Where("id = ?", account.ID).First(&unchangedAccount), "finding unchanged account")
var unchangedUser database.User
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&unchangedUser), "finding unchanged user")
// Verify old password still works
passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedAccount.Password.String), []byte("oldpassword123"))
passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedUser.Password.String), []byte("oldpassword123"))
assert.Equal(t, passwordErr, nil, "Old password should still match after rollback")
})
t.Run("transaction commit", func(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123")
user := testutils.SetupUserData(db, "alice@example.com", "oldpassword123")
// Start a transaction and commit to verify UpdateAccountPassword respects transactions
// Start a transaction and commit to verify UpdateUserPassword respects transactions
tx := db.Begin()
err := UpdateAccountPassword(tx, &account, "newpassword123")
err := UpdateUserPassword(tx, &user, "newpassword123")
assert.Equal(t, err, nil, "should not error")
tx.Commit()
// Verify password was updated after commit
var updatedAccount database.Account
testutils.MustExec(t, db.Where("id = ?", account.ID).First(&updatedAccount), "finding updated account")
var updatedUser database.User
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&updatedUser), "finding updated user")
// Verify new password works
passwordErr := bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("newpassword123"))
passwordErr := bcrypt.CompareHashAndPassword([]byte(updatedUser.Password.String), []byte("newpassword123"))
assert.Equal(t, passwordErr, nil, "New password should match after commit")
})
}
@ -268,8 +256,7 @@ func TestRemoveUser(t *testing.T) {
t.Run("success", func(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@example.com", "password123")
testutils.SetupUserData(db, "alice@example.com", "password123")
a := NewTest()
a.DB = db
@ -282,11 +269,6 @@ func TestRemoveUser(t *testing.T) {
var userCount int64
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting users")
assert.Equal(t, userCount, int64(0), "user should be deleted")
// Verify account was deleted
var accountCount int64
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting accounts")
assert.Equal(t, accountCount, int64(0), "account should be deleted")
})
t.Run("user not found", func(t *testing.T) {
@ -303,8 +285,7 @@ func TestRemoveUser(t *testing.T) {
t.Run("user has notes", func(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@example.com", "password123")
user := testutils.SetupUserData(db, "alice@example.com", "password123")
book := database.Book{UserID: user.ID, Label: "testbook", Deleted: false}
testutils.MustExec(t, db.Save(&book), "creating book")
@ -324,17 +305,12 @@ func TestRemoveUser(t *testing.T) {
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting users")
assert.Equal(t, userCount, int64(1), "user should not be deleted")
// Verify account was NOT deleted
var accountCount int64
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting accounts")
assert.Equal(t, accountCount, int64(1), "account should not be deleted")
})
t.Run("user has books", func(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@example.com", "password123")
user := testutils.SetupUserData(db, "alice@example.com", "password123")
book := database.Book{UserID: user.ID, Label: "testbook", Deleted: false}
testutils.MustExec(t, db.Save(&book), "creating book")
@ -351,17 +327,12 @@ func TestRemoveUser(t *testing.T) {
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting users")
assert.Equal(t, userCount, int64(1), "user should not be deleted")
// Verify account was NOT deleted
var accountCount int64
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting accounts")
assert.Equal(t, accountCount, int64(1), "account should not be deleted")
})
t.Run("user has deleted notes and books", func(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@example.com", "password123")
user := testutils.SetupUserData(db, "alice@example.com", "password123")
book := database.Book{UserID: user.ID, Label: "testbook", Deleted: false}
testutils.MustExec(t, db.Save(&book), "creating book")
@ -385,9 +356,5 @@ func TestRemoveUser(t *testing.T) {
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting users")
assert.Equal(t, userCount, int64(0), "user should be deleted")
// Verify account was deleted
var accountCount int64
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting accounts")
assert.Equal(t, accountCount, int64(0), "account should be deleted")
})
}

View file

@ -81,12 +81,12 @@ func userRemoveCmd(args []string, stdin io.Reader) {
defer cleanup()
// Check if user exists first
_, err := a.GetAccountByEmail(*email)
_, err := a.GetUserByEmail(*email)
if err != nil {
if errors.Is(err, app.ErrNotFound) {
fmt.Printf("Error: user with email %s not found\n", *email)
} else {
log.ErrorWrap(err, "finding account")
log.ErrorWrap(err, "finding user")
}
os.Exit(1)
}
@ -133,19 +133,19 @@ func userResetPasswordCmd(args []string) {
a, cleanup := setupAppWithDB(fs, *dbPath)
defer cleanup()
// Find the account
account, err := a.GetAccountByEmail(*email)
// Find the user
user, err := a.GetUserByEmail(*email)
if err != nil {
if errors.Is(err, app.ErrNotFound) {
fmt.Printf("Error: user with email %s not found\n", *email)
} else {
log.ErrorWrap(err, "finding account")
log.ErrorWrap(err, "finding user")
}
os.Exit(1)
}
// Update the password
if err := app.UpdateAccountPassword(a.DB, account, *password); err != nil {
if err := app.UpdateUserPassword(a.DB, user, *password); err != nil {
log.ErrorWrap(err, "updating password")
os.Exit(1)
}

View file

@ -45,9 +45,9 @@ func TestUserCreateCmd(t *testing.T) {
testutils.MustExec(t, db.Model(&database.User{}).Count(&count), "counting users")
assert.Equal(t, count, int64(1), "should have 1 user")
var account database.Account
testutils.MustExec(t, db.Where("email = ?", "test@example.com").First(&account), "finding account")
assert.Equal(t, account.Email.String, "test@example.com", "email mismatch")
var user database.User
testutils.MustExec(t, db.Where("email = ?", "test@example.com").First(&user), "finding user")
assert.Equal(t, user.Email.String, "test@example.com", "email mismatch")
}
func TestUserRemoveCmd(t *testing.T) {
@ -55,8 +55,7 @@ func TestUserRemoveCmd(t *testing.T) {
// Create a user first
db := testutils.InitDB(tmpDB)
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "test@example.com", "password123")
testutils.SetupUserData(db, "test@example.com", "password123")
sqlDB, _ := db.DB()
sqlDB.Close()
@ -81,9 +80,8 @@ func TestUserResetPasswordCmd(t *testing.T) {
// Create a user first
db := testutils.InitDB(tmpDB)
user := testutils.SetupUserData(db)
account := testutils.SetupAccountData(db, user, "test@example.com", "oldpassword123")
oldPasswordHash := account.Password.String
user := testutils.SetupUserData(db, "test@example.com", "oldpassword123")
oldPasswordHash := user.Password.String
sqlDB, _ := db.DB()
sqlDB.Close()
@ -97,18 +95,18 @@ func TestUserResetPasswordCmd(t *testing.T) {
sqlDB2.Close()
}()
var updatedAccount database.Account
testutils.MustExec(t, db2.Where("email = ?", "test@example.com").First(&updatedAccount), "finding account")
var updatedUser database.User
testutils.MustExec(t, db2.Where("email = ?", "test@example.com").First(&updatedUser), "finding user")
// Verify password hash changed
assert.Equal(t, updatedAccount.Password.String != oldPasswordHash, true, "password hash should be different")
assert.Equal(t, len(updatedAccount.Password.String) > 0, true, "password should be set")
assert.Equal(t, updatedUser.Password.String != oldPasswordHash, true, "password hash should be different")
assert.Equal(t, len(updatedUser.Password.String) > 0, true, "password should be set")
// Verify new password works
err := bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("newpassword123"))
err := bcrypt.CompareHashAndPassword([]byte(updatedUser.Password.String), []byte("newpassword123"))
assert.Equal(t, err, nil, "new password should match")
// Verify old password doesn't work
err = bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("oldpassword123"))
err = bcrypt.CompareHashAndPassword([]byte(updatedUser.Password.String), []byte("oldpassword123"))
assert.Equal(t, err != nil, true, "old password should not match")
}

View file

@ -25,9 +25,8 @@ import (
)
const (
userKey privateKey = "user"
accountKey privateKey = "account"
tokenKey privateKey = "token"
userKey privateKey = "user"
tokenKey privateKey = "token"
)
type privateKey string
@ -37,11 +36,6 @@ func WithUser(ctx context.Context, user *database.User) context.Context {
return context.WithValue(ctx, userKey, user)
}
// WithAccount creates a new context with the given account
func WithAccount(ctx context.Context, account *database.Account) context.Context {
return context.WithValue(ctx, accountKey, account)
}
// WithToken creates a new context with the given user
func WithToken(ctx context.Context, tok *database.Token) context.Context {
return context.WithValue(ctx, tokenKey, tok)
@ -59,17 +53,6 @@ func User(ctx context.Context) *database.User {
return nil
}
// Account retrieves an account from the given context.
func Account(ctx context.Context) *database.Account {
if temp := ctx.Value(accountKey); temp != nil {
if account, ok := temp.(*database.Account); ok {
return account
}
}
return nil
}
// Token retrieves a token from the given context.
func Token(ctx context.Context) *database.Token {
if temp := ctx.Value(tokenKey); temp != nil {

View file

@ -50,10 +50,8 @@ func TestGetBooks(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
anotherUser := testutils.SetupUserData(db)
testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
anotherUser := testutils.SetupUserData(db, "bob@test.com", "pass1234")
b1 := database.Book{
UUID: testutils.MustUUID(t),
@ -143,10 +141,8 @@ func TestGetBooksByName(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
anotherUser := testutils.SetupUserData(db)
testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
anotherUser := testutils.SetupUserData(db, "bob@test.com", "pass1234")
b1 := database.Book{
UUID: testutils.MustUUID(t),
@ -212,10 +208,8 @@ func TestGetBook(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
anotherUser := testutils.SetupUserData(db)
testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
anotherUser := testutils.SetupUserData(db, "bob@test.com", "pass1234")
b1 := database.Book{
UUID: testutils.MustUUID(t),
@ -276,10 +270,8 @@ func TestGetBookNonOwner(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
nonOwner := testutils.SetupUserData(db)
testutils.SetupAccountData(db, nonOwner, "bob@test.com", "pass1234")
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
nonOwner := testutils.SetupUserData(db, "bob@test.com", "pass1234")
b1 := database.Book{
UUID: testutils.MustUUID(t),
@ -314,8 +306,7 @@ func TestCreateBook(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
req := testutils.MakeReq(server.URL, "POST", "/api/v3/books", `{"name": "js"}`)
@ -375,8 +366,7 @@ func TestCreateBook(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
@ -465,8 +455,7 @@ func TestUpdateBook(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
@ -550,11 +539,9 @@ func TestDeleteBook(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
testutils.MustExec(t, db.Model(&user).Update("max_usn", 58), "preparing user max_usn")
anotherUser := testutils.SetupUserData(db)
testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
anotherUser := testutils.SetupUserData(db, "bob@test.com", "pass1234")
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 109), "preparing another user max_usn")
b1 := database.Book{

View file

@ -63,10 +63,8 @@ func TestGetNotes(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
anotherUser := testutils.SetupUserData(db)
testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
anotherUser := testutils.SetupUserData(db, "bob@test.com", "pass1234")
b1 := database.Book{
UUID: testutils.MustUUID(t),
@ -187,10 +185,8 @@ func TestGetNote(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "user@test.com", "pass1234")
anotherUser := testutils.SetupUserData(db)
testutils.SetupAccountData(db, anotherUser, "another@test.com", "pass1234")
user := testutils.SetupUserData(db, "user@test.com", "pass1234")
anotherUser := testutils.SetupUserData(db, "another@test.com", "pass1234")
b1 := database.Book{
UUID: testutils.MustUUID(t),
@ -318,8 +314,7 @@ func TestCreateNote(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
@ -400,8 +395,7 @@ func TestDeleteNote(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
testutils.MustExec(t, db.Model(&user).Update("max_usn", 981), "preparing user max_usn")
b1 := database.Book{
@ -559,8 +553,7 @@ func TestUpdateNote(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")

View file

@ -307,23 +307,23 @@ func (u *Users) CreateResetToken(w http.ResponseWriter, r *http.Request) {
return
}
var account database.Account
err := u.app.DB.Where("email = ?", form.Email).First(&account).Error
var user database.User
err := u.app.DB.Where("email = ?", form.Email).First(&user).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return
}
if err != nil {
handleHTMLError(w, r, err, "finding account", u.PasswordResetView, vd)
handleHTMLError(w, r, err, "finding user", u.PasswordResetView, vd)
return
}
resetToken, err := token.Create(u.app.DB, account.UserID, database.TokenTypeResetPassword)
resetToken, err := token.Create(u.app.DB, user.ID, database.TokenTypeResetPassword)
if err != nil {
handleHTMLError(w, r, err, "generating token", u.PasswordResetView, vd)
return
}
if err := u.app.SendPasswordResetEmail(account.Email.String, resetToken.Value); err != nil {
if err := u.app.SendPasswordResetEmail(user.Email.String, resetToken.Value); err != nil {
handleHTMLError(w, r, err, "sending password reset email", u.PasswordResetView, vd)
return
}
@ -396,8 +396,8 @@ func (u *Users) PasswordReset(w http.ResponseWriter, r *http.Request) {
return
}
var account database.Account
if err := u.app.DB.Where("user_id = ?", token.UserID).First(&account).Error; err != nil {
var user database.User
if err := u.app.DB.Where("id = ?", token.UserID).First(&user).Error; err != nil {
handleHTMLError(w, r, err, "finding user", u.PasswordResetConfirmView, vd)
return
}
@ -405,7 +405,7 @@ func (u *Users) PasswordReset(w http.ResponseWriter, r *http.Request) {
tx := u.app.DB.Begin()
// Update the password
if err := app.UpdateAccountPassword(tx, &account, params.Password); err != nil {
if err := app.UpdateUserPassword(tx, &user, params.Password); err != nil {
tx.Rollback()
handleHTMLError(w, r, err, "updating password", u.PasswordResetConfirmView, vd)
return
@ -417,7 +417,7 @@ func (u *Users) PasswordReset(w http.ResponseWriter, r *http.Request) {
return
}
if err := u.app.DeleteUserSessions(tx, account.UserID); err != nil {
if err := u.app.DeleteUserSessions(tx, user.ID); err != nil {
tx.Rollback()
handleHTMLError(w, r, err, "deleting user sessions", u.PasswordResetConfirmView, vd)
return
@ -425,19 +425,13 @@ func (u *Users) PasswordReset(w http.ResponseWriter, r *http.Request) {
tx.Commit()
var user database.User
if err := u.app.DB.Where("id = ?", account.UserID).First(&user).Error; err != nil {
handleHTMLError(w, r, err, "finding user", u.PasswordResetConfirmView, vd)
return
}
alert := views.Alert{
Level: views.AlertLvlSuccess,
Message: "Password reset successful",
}
views.RedirectAlert(w, r, "/login", http.StatusFound, alert)
if err := u.app.SendPasswordResetAlertEmail(account.Email.String); err != nil {
if err := u.app.SendPasswordResetAlertEmail(user.Email.String); err != nil {
log.ErrorWrap(err, "sending password reset email")
}
}
@ -493,14 +487,8 @@ func (u *Users) PasswordUpdate(w http.ResponseWriter, r *http.Request) {
return
}
var account database.Account
if err := u.app.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
handleHTMLError(w, r, err, "getting account", u.SettingView, vd)
return
}
password := []byte(form.OldPassword)
if err := bcrypt.CompareHashAndPassword([]byte(account.Password.String), password); err != nil {
if err := bcrypt.CompareHashAndPassword([]byte(user.Password.String), password); err != nil {
log.WithFields(log.Fields{
"user_id": user.ID,
}).Warn("invalid password update attempt")
@ -508,7 +496,7 @@ func (u *Users) PasswordUpdate(w http.ResponseWriter, r *http.Request) {
return
}
if err := app.UpdateAccountPassword(u.app.DB, &account, form.NewPassword); err != nil {
if err := app.UpdateUserPassword(u.app.DB, user, form.NewPassword); err != nil {
handleHTMLError(w, r, err, "updating password", u.SettingView, vd)
return
}
@ -534,12 +522,6 @@ func (u *Users) ProfileUpdate(w http.ResponseWriter, r *http.Request) {
return
}
var account database.Account
if err := u.app.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
handleHTMLError(w, r, err, "getting account", u.SettingView, vd)
return
}
var form updateProfileForm
if err := parseRequestData(r, &form); err != nil {
handleHTMLError(w, r, err, "parsing payload", u.SettingView, vd)
@ -547,7 +529,7 @@ func (u *Users) ProfileUpdate(w http.ResponseWriter, r *http.Request) {
}
password := []byte(form.Password)
if err := bcrypt.CompareHashAndPassword([]byte(account.Password.String), password); err != nil {
if err := bcrypt.CompareHashAndPassword([]byte(user.Password.String), password); err != nil {
log.WithFields(log.Fields{
"user_id": user.ID,
}).Warn("invalid email update attempt")
@ -561,23 +543,13 @@ func (u *Users) ProfileUpdate(w http.ResponseWriter, r *http.Request) {
return
}
tx := u.app.DB.Begin()
if err := tx.Save(&user).Error; err != nil {
tx.Rollback()
user.Email.String = form.Email
if err := u.app.DB.Save(&user).Error; err != nil {
handleHTMLError(w, r, err, "saving user", u.SettingView, vd)
return
}
account.Email.String = form.Email
if err := tx.Save(&account).Error; err != nil {
tx.Rollback()
handleHTMLError(w, r, err, "saving account", u.SettingView, vd)
return
}
tx.Commit()
alert := views.Alert{
Level: views.AlertLvlSuccess,
Message: "Email change successful",

View file

@ -98,15 +98,14 @@ func TestJoin(t *testing.T) {
// Test
assert.StatusCodeEquals(t, res, http.StatusFound, "")
var account database.Account
testutils.MustExec(t, db.Where("email = ?", tc.email).First(&account), "finding account")
assert.Equal(t, account.Email.String, tc.email, "Email mismatch")
assert.NotEqual(t, account.UserID, 0, "UserID mismatch")
passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte(tc.password))
var user database.User
testutils.MustExec(t, db.Where("email = ?", tc.email).First(&user), "finding account")
assert.Equal(t, user.Email.String, tc.email, "Email mismatch")
assert.NotEqual(t, user.ID, 0, "UserID mismatch")
passwordErr := bcrypt.CompareHashAndPassword([]byte(user.Password.String), []byte(tc.password))
assert.Equal(t, passwordErr, nil, "Password mismatch")
var user database.User
testutils.MustExec(t, db.Where("id = ?", account.UserID).First(&user), "finding user")
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&user), "finding user")
assert.Equal(t, user.MaxUSN, 0, "MaxUSN mismatch")
// welcome email
@ -140,11 +139,9 @@ func TestJoinError(t *testing.T) {
// Test
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch")
var accountCount, userCount int64
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
var userCount int64
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
assert.Equal(t, accountCount, int64(0), "accountCount mismatch")
assert.Equal(t, userCount, int64(0), "userCount mismatch")
})
@ -168,11 +165,9 @@ func TestJoinError(t *testing.T) {
// Test
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch")
var accountCount, userCount int64
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
var userCount int64
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
assert.Equal(t, accountCount, int64(0), "accountCount mismatch")
assert.Equal(t, userCount, int64(0), "userCount mismatch")
})
@ -198,11 +193,9 @@ func TestJoinError(t *testing.T) {
// Test
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch")
var accountCount, userCount int64
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
var userCount int64
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
assert.Equal(t, accountCount, int64(0), "accountCount mismatch")
assert.Equal(t, userCount, int64(0), "userCount mismatch")
})
}
@ -217,8 +210,7 @@ func TestJoinDuplicateEmail(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
u := testutils.SetupUserData(db)
testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
u := testutils.SetupUserData(db, "alice@example.com", "somepassword")
dat := url.Values{}
dat.Set("email", "alice@example.com")
@ -232,15 +224,13 @@ func TestJoinDuplicateEmail(t *testing.T) {
// Test
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "status code mismatch")
var accountCount, userCount, verificationTokenCount int64
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
var userCount, verificationTokenCount int64
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
testutils.MustExec(t, db.Model(&database.Token{}).Count(&verificationTokenCount), "counting verification token")
var user database.User
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user")
assert.Equal(t, accountCount, int64(1), "account count mismatch")
assert.Equal(t, userCount, int64(1), "user count mismatch")
assert.Equal(t, verificationTokenCount, int64(0), "verification_token should not have been created")
assert.Equal(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch")
@ -268,11 +258,9 @@ func TestJoinDisabled(t *testing.T) {
// Test
assert.StatusCodeEquals(t, res, http.StatusNotFound, "status code mismatch")
var accountCount, userCount int64
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
var userCount int64
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
assert.Equal(t, accountCount, int64(0), "account count mismatch")
assert.Equal(t, userCount, int64(0), "user count mismatch")
}
@ -286,8 +274,7 @@ func TestLogin(t *testing.T) {
a.DB = db
server := MustNewServer(t, &a)
u := testutils.SetupUserData(db)
testutils.SetupAccountData(db, u, "alice@example.com", "pass1234")
_ = testutils.SetupUserData(db, "alice@example.com", "pass1234")
defer server.Close()
// Execute
@ -346,8 +333,7 @@ func TestLogin(t *testing.T) {
a.DB = db
server := MustNewServer(t, &a)
u := testutils.SetupUserData(db)
testutils.SetupAccountData(db, u, "alice@example.com", "pass1234")
_ = testutils.SetupUserData(db, "alice@example.com", "pass1234")
defer server.Close()
var req *http.Request
@ -386,8 +372,7 @@ func TestLogin(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
u := testutils.SetupUserData(db)
testutils.SetupAccountData(db, u, "alice@example.com", "pass1234")
_ = testutils.SetupUserData(db, "alice@example.com", "pass1234")
var req *http.Request
if target == testutils.EndpointWeb {
@ -456,9 +441,8 @@ func TestLogout(t *testing.T) {
a.DB = db
server := MustNewServer(t, &a)
aliceUser := testutils.SetupUserData(db)
testutils.SetupAccountData(db, aliceUser, "alice@example.com", "pass1234")
anotherUser := testutils.SetupUserData(db)
aliceUser := testutils.SetupUserData(db, "alice@example.com", "pass1234")
anotherUser := testutils.SetupUserData(db, "bob@example.com", "password123")
session1ExpiresAt := time.Now().Add(time.Hour * 24)
session1 := database.Session{
@ -570,8 +554,7 @@ func TestResetPassword(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
u := testutils.SetupUserData(db)
acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword")
u := testutils.SetupUserData(db, "alice@example.com", "oldpassword")
tok := database.Token{
UserID: u.ID,
Value: "MivFxYiSMMA4An9dP24DNQ==",
@ -593,7 +576,7 @@ func TestResetPassword(t *testing.T) {
}
testutils.MustExec(t, db.Save(&s2), "preparing user session 2")
anotherUser := testutils.SetupUserData(db)
anotherUser := testutils.SetupUserData(db, "bob@example.com", "password123")
testutils.MustExec(t, db.Save(&database.Session{
Key: "some-session-key-3",
UserID: anotherUser.ID,
@ -613,12 +596,12 @@ func TestResetPassword(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusFound, "Status code mismatch")
var resetToken database.Token
var account database.Account
var user database.User
testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "finding account")
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding account")
assert.NotEqual(t, resetToken.UsedAt, nil, "reset_token UsedAt mismatch")
passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte("newpassword"))
passwordErr := bcrypt.CompareHashAndPassword([]byte(user.Password.String), []byte("newpassword"))
assert.Equal(t, passwordErr, nil, "Password mismatch")
var s1Count, s2Count int64
@ -646,8 +629,7 @@ func TestResetPassword(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
u := testutils.SetupUserData(db)
acc := testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
u := testutils.SetupUserData(db, "alice@example.com", "somepassword")
tok := database.Token{
UserID: u.ID,
Value: "MivFxYiSMMA4An9dP24DNQ==",
@ -668,12 +650,12 @@ func TestResetPassword(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismatch")
var resetToken database.Token
var account database.Account
var user database.User
testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "finding account")
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding account")
assert.Equal(t, acc.Password, account.Password, "password should not have been updated")
assert.Equal(t, acc.Password, account.Password, "password should not have been updated")
assert.Equal(t, u.Password, user.Password, "password should not have been updated")
assert.Equal(t, u.Password, user.Password, "password should not have been updated")
assert.Equal(t, resetToken.UsedAt, (*time.Time)(nil), "used_at should be nil")
})
@ -687,8 +669,7 @@ func TestResetPassword(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
u := testutils.SetupUserData(db)
acc := testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
u := testutils.SetupUserData(db, "alice@example.com", "somepassword")
tok := database.Token{
UserID: u.ID,
Value: "MivFxYiSMMA4An9dP24DNQ==",
@ -710,10 +691,10 @@ func TestResetPassword(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusGone, "Status code mismatch")
var resetToken database.Token
var account database.Account
var user database.User
testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "failed to find account")
assert.Equal(t, acc.Password, account.Password, "password should not have been updated")
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "failed to find account")
assert.Equal(t, u.Password, user.Password, "password should not have been updated")
assert.Equal(t, resetToken.UsedAt, (*time.Time)(nil), "used_at should be nil")
})
@ -727,8 +708,7 @@ func TestResetPassword(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
u := testutils.SetupUserData(db)
acc := testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
u := testutils.SetupUserData(db, "alice@example.com", "somepassword")
usedAt := time.Now().Add(time.Hour * -11).UTC()
tok := database.Token{
@ -753,10 +733,10 @@ func TestResetPassword(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismatch")
var resetToken database.Token
var account database.Account
var user database.User
testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "failed to find account")
assert.Equal(t, acc.Password, account.Password, "password should not have been updated")
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "failed to find account")
assert.Equal(t, u.Password, user.Password, "password should not have been updated")
resetTokenUsedAtUTC := resetToken.UsedAt.UTC()
if resetTokenUsedAtUTC.Year() != usedAt.Year() ||
@ -782,8 +762,7 @@ func TestCreateResetToken(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
u := testutils.SetupUserData(db)
testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
u := testutils.SetupUserData(db, "alice@example.com", "somepassword")
// Execute
dat := url.Values{}
@ -816,8 +795,7 @@ func TestCreateResetToken(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
u := testutils.SetupUserData(db)
testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
_ = testutils.SetupUserData(db, "alice@example.com", "somepassword")
// Execute
dat := url.Values{}
@ -846,8 +824,7 @@ func TestUpdatePassword(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword")
user := testutils.SetupUserData(db, "alice@example.com", "oldpassword")
// Execute
dat := url.Values{}
@ -861,10 +838,9 @@ func TestUpdatePassword(t *testing.T) {
// Test
assert.StatusCodeEquals(t, res, http.StatusFound, "Status code mismsatch")
var account database.Account
testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&user), "finding account")
passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte("newpassword"))
passwordErr := bcrypt.CompareHashAndPassword([]byte(user.Password.String), []byte("newpassword"))
assert.Equal(t, passwordErr, nil, "Password mismatch")
})
@ -877,8 +853,7 @@ func TestUpdatePassword(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
u := testutils.SetupUserData(db)
acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword")
u := testutils.SetupUserData(db, "alice@example.com", "oldpassword")
// Execute
dat := url.Values{}
@ -892,9 +867,9 @@ func TestUpdatePassword(t *testing.T) {
// Test
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "Status code mismsatch")
var account database.Account
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
assert.Equal(t, acc.Password.String, account.Password.String, "password should not have been updated")
var user database.User
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding account")
assert.Equal(t, u.Password.String, user.Password.String, "password should not have been updated")
})
t.Run("password too short", func(t *testing.T) {
@ -907,8 +882,7 @@ func TestUpdatePassword(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
u := testutils.SetupUserData(db)
acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword")
u := testutils.SetupUserData(db, "alice@example.com", "oldpassword")
// Execute
dat := url.Values{}
@ -922,9 +896,9 @@ func TestUpdatePassword(t *testing.T) {
// Test
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismsatch")
var account database.Account
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
assert.Equal(t, acc.Password.String, account.Password.String, "password should not have been updated")
var user database.User
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding account")
assert.Equal(t, u.Password.String, user.Password.String, "password should not have been updated")
})
t.Run("password confirmation mismatch", func(t *testing.T) {
@ -937,8 +911,7 @@ func TestUpdatePassword(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
u := testutils.SetupUserData(db)
acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword")
u := testutils.SetupUserData(db, "alice@example.com", "oldpassword")
// Execute
dat := url.Values{}
@ -952,9 +925,9 @@ func TestUpdatePassword(t *testing.T) {
// Test
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismsatch")
var account database.Account
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
assert.Equal(t, acc.Password.String, account.Password.String, "password should not have been updated")
var user database.User
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding account")
assert.Equal(t, u.Password.String, user.Password.String, "password should not have been updated")
})
}
@ -969,8 +942,7 @@ func TestUpdateEmail(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
u := testutils.SetupUserData(db)
testutils.SetupAccountData(db, u, "alice@example.com", "pass1234")
u := testutils.SetupUserData(db, "alice@example.com", "pass1234")
// Execute
dat := url.Values{}
@ -984,11 +956,10 @@ func TestUpdateEmail(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusFound, "Status code mismatch")
var user database.User
var account database.Account
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user")
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding account")
assert.Equal(t, account.Email.String, "alice-new@example.com", "email mismatch")
assert.Equal(t, user.Email.String, "alice-new@example.com", "email mismatch")
})
t.Run("password mismatch", func(t *testing.T) {
@ -1001,8 +972,7 @@ func TestUpdateEmail(t *testing.T) {
server := MustNewServer(t, &a)
defer server.Close()
u := testutils.SetupUserData(db)
testutils.SetupAccountData(db, u, "alice@example.com", "pass1234")
u := testutils.SetupUserData(db, "alice@example.com", "pass1234")
// Execute
dat := url.Values{}
@ -1016,11 +986,9 @@ func TestUpdateEmail(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "Status code mismsatch")
var user database.User
var account database.Account
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user")
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
assert.Equal(t, account.Email.String, "alice@example.com", "email mismatch")
assert.Equal(t, user.Email.String, "alice@example.com", "email mismatch")
})
}

View file

@ -36,7 +36,6 @@ var (
func InitSchema(db *gorm.DB) {
if err := db.AutoMigrate(
&User{},
&Account{},
&Book{},
&Note{},
&Token{},

View file

@ -61,20 +61,13 @@ type Note struct {
// User is a model for a user
type User struct {
Model
UUID string `json:"uuid" gorm:"type:text;index"`
Account Account `gorm:"foreignKey:UserID"`
UUID string `json:"uuid" gorm:"type:text;index"`
Email NullString `gorm:"index"`
Password NullString `json:"-"`
LastLoginAt *time.Time `json:"-"`
MaxUSN int `json:"-" gorm:"default:0"`
}
// Account is a model for an account
type Account struct {
Model
UserID int `gorm:"index"`
Email NullString
Password NullString
}
// Token is a model for a token
type Token struct {
Model

View file

@ -67,8 +67,6 @@ type AuthParams struct {
// Auth is an authentication middleware
func Auth(db *gorm.DB, next http.HandlerFunc, p *AuthParams) http.HandlerFunc {
next = WithAccount(db, next)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, ok, err := AuthWithSession(db, r)
if !ok {
@ -93,27 +91,6 @@ func Auth(db *gorm.DB, next http.HandlerFunc, p *AuthParams) http.HandlerFunc {
ctx := context.WithUser(r.Context(), &user)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func WithAccount(db *gorm.DB, next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := context.User(r.Context())
var account database.Account
err := db.Where("user_id = ?", user.ID).First(&account).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
DoError(w, "account not found", err, http.StatusForbidden)
return
} else if err != nil {
DoError(w, "finding account", err, http.StatusInternalServerError)
return
}
ctx := context.WithAccount(r.Context(), &account)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// TokenAuth is an authentication middleware with token

View file

@ -47,7 +47,7 @@ func TestGuestOnly(t *testing.T) {
})
t.Run("logged in", func(t *testing.T) {
user := testutils.SetupUserData(db)
user := testutils.SetupUserData(db, "user@test.com", "password123")
req := testutils.MakeReq(server.URL, "GET", "/", "")
res := testutils.HTTPAuthDo(t, db, req, user)
@ -67,8 +67,7 @@ func TestGuestOnly(t *testing.T) {
func TestAuth(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
session := database.Session{
Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=",
@ -175,7 +174,7 @@ func TestAuth(t *testing.T) {
func TestTokenAuth(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
user := testutils.SetupUserData(db, "user@test.com", "password123")
tok := database.Token{
UserID: user.ID,
Type: database.TokenTypeResetPassword,
@ -241,9 +240,8 @@ func TestWithAccount(t *testing.T) {
w.WriteHeader(http.StatusOK)
}
t.Run("user with account", func(t *testing.T) {
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
t.Run("authenticated user", func(t *testing.T) {
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
server := httptest.NewServer(Auth(db, handler, nil))
defer server.Close()
@ -253,17 +251,4 @@ func TestWithAccount(t *testing.T) {
assert.Equal(t, res.StatusCode, http.StatusOK, "status code mismatch")
})
t.Run("user without account", func(t *testing.T) {
user := testutils.SetupUserData(db)
// Note: not creating account for this user
server := httptest.NewServer(Auth(db, handler, nil))
defer server.Close()
req := testutils.MakeReq(server.URL, "GET", "/", "")
res := testutils.HTTPAuthDo(t, db, req, user)
assert.Equal(t, res.StatusCode, http.StatusForbidden, "status code mismatch")
})
}

View file

@ -30,8 +30,8 @@ import (
func TestGetNote(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
anotherUser := testutils.SetupUserData(db)
user := testutils.SetupUserData(db, "user@test.com", "password123")
anotherUser := testutils.SetupUserData(db, "another@test.com", "password123")
b1 := database.Book{
UUID: testutils.MustUUID(t),
@ -98,7 +98,7 @@ func TestGetNote(t *testing.T) {
func TestGetNote_nonexistent(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
user := testutils.SetupUserData(db, "user@test.com", "password123")
b1 := database.Book{
UUID: testutils.MustUUID(t),

View file

@ -29,8 +29,8 @@ import (
func TestViewNote(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
anotherUser := testutils.SetupUserData(db)
user := testutils.SetupUserData(db, "user@test.com", "password123")
anotherUser := testutils.SetupUserData(db, "another@test.com", "password123")
b1 := database.Book{
UUID: testutils.MustUUID(t),

View file

@ -29,9 +29,9 @@ type Session struct {
}
// New returns a new session for the given user
func New(user database.User, account database.Account) Session {
func New(user database.User) Session {
return Session{
UUID: user.UUID,
Email: account.Email.String,
Email: user.Email.String,
}
}

View file

@ -27,33 +27,34 @@ import (
)
func TestNew(t *testing.T) {
u1 := database.User{UUID: "0f5f0054-d23f-4be1-b5fb-57673109e9cb"}
a1 := database.Account{Email: database.ToNullString("alice@example.com")}
u1 := database.User{
UUID: "0f5f0054-d23f-4be1-b5fb-57673109e9cb",
Email: database.ToNullString("alice@example.com"),
}
u2 := database.User{UUID: "718a1041-bbe6-496e-bbe4-ea7e572c295e"}
a2 := database.Account{Email: database.ToNullString("bob@example.com")}
u2 := database.User{
UUID: "718a1041-bbe6-496e-bbe4-ea7e572c295e",
Email: database.ToNullString("bob@example.com"),
}
testCases := []struct {
user database.User
account database.Account
user database.User
}{
{
user: u1,
account: a1,
user: u1,
},
{
user: u2,
account: a2,
user: u2,
},
}
for idx, tc := range testCases {
t.Run(fmt.Sprintf("user %d", idx), func(t *testing.T) {
// Execute
got := New(tc.user, tc.account)
got := New(tc.user)
expected := Session{
UUID: tc.user.UUID,
Email: tc.account.Email.String,
Email: tc.user.Email.String,
}
assert.DeepEqual(t, got, expected, "result mismatch")

View file

@ -82,9 +82,6 @@ func ClearData(db *gorm.DB) {
if err := db.Where("1 = 1").Delete(&database.Session{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear sessions"))
}
if err := db.Where("1 = 1").Delete(&database.Account{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear accounts"))
}
if err := db.Where("1 = 1").Delete(&database.User{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear users"))
}
@ -99,15 +96,22 @@ func MustUUID(t *testing.T) string {
return uuid
}
// SetupUserData creates and returns a new user for testing purposes
func SetupUserData(db *gorm.DB) database.User {
// SetupUserData creates and returns a new user with email and password for testing purposes
func SetupUserData(db *gorm.DB, email, password string) database.User {
uuid, err := helpers.GenUUID()
if err != nil {
panic(errors.Wrap(err, "Failed to generate UUID"))
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
panic(errors.Wrap(err, "Failed to hash password"))
}
user := database.User{
UUID: uuid,
UUID: uuid,
Email: database.ToNullString(email),
Password: database.ToNullString(string(hashedPassword)),
}
if err := db.Save(&user).Error; err != nil {
@ -117,28 +121,6 @@ func SetupUserData(db *gorm.DB) database.User {
return user
}
// SetupAccountData creates and returns a new account for the user
func SetupAccountData(db *gorm.DB, user database.User, email, password string) database.Account {
account := database.Account{
UserID: user.ID,
}
if email != "" {
account.Email = database.ToNullString(email)
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
panic(errors.Wrap(err, "Failed to hash password"))
}
account.Password = database.ToNullString(string(hashedPassword))
if err := db.Save(&account).Error; err != nil {
panic(errors.Wrap(err, "Failed to prepare account"))
}
return account
}
// SetupSession creates and returns a new user session
func SetupSession(db *gorm.DB, user database.User) database.Session {
session := database.Session{

View file

@ -42,7 +42,7 @@ func TestCreate(t *testing.T) {
db := testutils.InitMemoryDB(t)
// Set up
u := testutils.SetupUserData(db)
u := testutils.SetupUserData(db, "user@test.com", "password123")
// Execute
tok, err := Create(db, u.ID, tc.kind)

View file

@ -50,9 +50,8 @@ type Alert struct {
type Data struct {
Alert *Alert
// CSRF template.HTML
User *database.User
Account *database.Account
Yield map[string]interface{}
User *database.User
Yield map[string]interface{}
}
func getErrMessage(err error) string {

View file

@ -41,7 +41,7 @@
</div>
<div class="email">
{{.Account.Email.String}}
{{.User.Email.String}}
</div>
</div>

View file

@ -108,14 +108,13 @@ func (v *View) Render(w http.ResponseWriter, r *http.Request, data *Data, status
}
vd.User = context.User(r.Context())
vd.Account = context.Account(r.Context())
// Put user data in Yield
if vd.Yield == nil {
vd.Yield = map[string]interface{}{}
}
if vd.Account != nil {
vd.Yield["Email"] = vd.Account.Email.String
if vd.User != nil {
vd.Yield["Email"] = vd.User.Email.String
}
vd.Yield["CurrentPath"] = r.URL.Path
vd.Yield["Standalone"] = buildinfo.Standalone