mirror of
https://github.com/dnote/dnote
synced 2026-03-14 14:35:50 +01:00
Merge user and account (#701)
This commit is contained in:
parent
b03ca999a5
commit
0a5728faf3
26 changed files with 248 additions and 477 deletions
|
|
@ -136,8 +136,7 @@ func TestMain(m *testing.M) {
|
|||
|
||||
// helpers
|
||||
func setupUser(t *testing.T, db *cliDatabase.DB) database.User {
|
||||
user := apitest.SetupUserData(serverDb)
|
||||
apitest.SetupAccountData(serverDb, user, "alice@example.com", "pass1234")
|
||||
user := apitest.SetupUserData(serverDb, "alice@example.com", "pass1234")
|
||||
|
||||
return user
|
||||
}
|
||||
|
|
@ -4255,8 +4254,7 @@ func TestSync_EmptyServer(t *testing.T) {
|
|||
// Step 1: Set up user on Server A and sync
|
||||
apiEndpointA := fmt.Sprintf("%s/api", serverA.URL)
|
||||
|
||||
userA := apitest.SetupUserData(serverDbA)
|
||||
apitest.SetupAccountData(serverDbA, userA, "alice@example.com", "pass1234")
|
||||
userA := apitest.SetupUserData(serverDbA, "alice@example.com", "pass1234")
|
||||
sessionA := apitest.SetupSession(serverDbA, userA)
|
||||
cliDatabase.MustExec(t, "inserting session_key", ctx.DB, "INSERT INTO system (key, value) VALUES (?, ?)", consts.SystemSessionKey, sessionA.Key)
|
||||
cliDatabase.MustExec(t, "inserting session_key_expiry", ctx.DB, "INSERT INTO system (key, value) VALUES (?, ?)", consts.SystemSessionKeyExpiry, sessionA.ExpiresAt.Unix())
|
||||
|
|
@ -4280,8 +4278,7 @@ func TestSync_EmptyServer(t *testing.T) {
|
|||
apiEndpointB := fmt.Sprintf("%s/api", serverB.URL)
|
||||
|
||||
// Set up user on Server B
|
||||
userB := apitest.SetupUserData(serverDbB)
|
||||
apitest.SetupAccountData(serverDbB, userB, "alice@example.com", "pass1234")
|
||||
userB := apitest.SetupUserData(serverDbB, "alice@example.com", "pass1234")
|
||||
sessionB := apitest.SetupSession(serverDbB, userB)
|
||||
cliDatabase.MustExec(t, "updating session_key for B", ctx.DB, "UPDATE system SET value = ? WHERE key = ?", sessionB.Key, consts.SystemSessionKey)
|
||||
cliDatabase.MustExec(t, "updating session_key_expiry for B", ctx.DB, "UPDATE system SET value = ? WHERE key = ?", sessionB.ExpiresAt.Unix(), consts.SystemSessionKeyExpiry)
|
||||
|
|
|
|||
|
|
@ -56,10 +56,10 @@ func TestCreateBook(t *testing.T) {
|
|||
func() {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
user := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
anotherUser := testutils.SetupUserData(db)
|
||||
anotherUser := testutils.SetupUserData(db, "another@test.com", "password123")
|
||||
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
a := NewTest()
|
||||
|
|
@ -122,10 +122,10 @@ func TestDeleteBook(t *testing.T) {
|
|||
func() {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
user := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
anotherUser := testutils.SetupUserData(db)
|
||||
anotherUser := testutils.SetupUserData(db, "another@test.com", "password123")
|
||||
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
book := database.Book{UserID: user.ID, Label: "js", Deleted: false}
|
||||
|
|
@ -201,10 +201,10 @@ func TestUpdateBook(t *testing.T) {
|
|||
func() {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
user := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
anotherUser := testutils.SetupUserData(db)
|
||||
anotherUser := testutils.SetupUserData(db, "another@test.com", "password123")
|
||||
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
b := database.Book{UserID: user.ID, Deleted: false, Label: tc.expectedLabel}
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ func TestIncremenetUserUSN(t *testing.T) {
|
|||
func() {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
user := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.maxUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
// execute
|
||||
|
|
|
|||
|
|
@ -77,11 +77,11 @@ func TestCreateNote(t *testing.T) {
|
|||
func() {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
user := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
fmt.Println(user)
|
||||
|
||||
anotherUser := testutils.SetupUserData(db)
|
||||
anotherUser := testutils.SetupUserData(db, "another@test.com", "password123")
|
||||
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
b1 := database.Book{UserID: user.ID, Label: "js", Deleted: false}
|
||||
|
|
@ -130,7 +130,7 @@ func TestCreateNote(t *testing.T) {
|
|||
func TestCreateNote_EmptyBody(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
user := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
b1 := database.Book{UserID: user.ID, Label: "testBook"}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing book")
|
||||
|
||||
|
|
@ -169,10 +169,10 @@ func TestUpdateNote(t *testing.T) {
|
|||
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
user := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), "preparing user max_usn for test case")
|
||||
|
||||
anotherUser := testutils.SetupUserData(db)
|
||||
anotherUser := testutils.SetupUserData(db, "another@test.com", "password123")
|
||||
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), "preparing user max_usn for test case")
|
||||
|
||||
b1 := database.Book{UserID: user.ID, Label: "js", Deleted: false}
|
||||
|
|
@ -234,7 +234,7 @@ func TestUpdateNote(t *testing.T) {
|
|||
func TestUpdateNote_SameContent(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
user := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
b1 := database.Book{UserID: user.ID, Label: "testBook"}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing book")
|
||||
|
||||
|
|
@ -291,10 +291,10 @@ func TestDeleteNote(t *testing.T) {
|
|||
func() {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
user := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
anotherUser := testutils.SetupUserData(db)
|
||||
anotherUser := testutils.SetupUserData(db, "another@test.com", "password123")
|
||||
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
b1 := database.Book{UserID: user.ID, Label: "testBook"}
|
||||
|
|
@ -351,7 +351,7 @@ func TestDeleteNote(t *testing.T) {
|
|||
func TestGetNotes_FTSSearch(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
user := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
b1 := database.Book{UserID: user.ID, Label: "testBook"}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing book")
|
||||
|
||||
|
|
@ -415,7 +415,7 @@ func TestGetNotes_FTSSearch(t *testing.T) {
|
|||
func TestGetNotes_FTSSearch_Snippet(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
user := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
b1 := database.Book{UserID: user.ID, Label: "testBook"}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing book")
|
||||
|
||||
|
|
@ -449,7 +449,7 @@ func TestGetNotes_FTSSearch_Snippet(t *testing.T) {
|
|||
func TestGetNotes_FTSSearch_ShortWord(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
user := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
b1 := database.Book{UserID: user.ID, Label: "testBook"}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing book")
|
||||
|
||||
|
|
@ -481,7 +481,7 @@ func TestGetNotes_FTSSearch_ShortWord(t *testing.T) {
|
|||
func TestGetNotes_All(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
user := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
b1 := database.Book{UserID: user.ID, Label: "testBook"}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing book")
|
||||
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ func (a *App) CreateUser(email, password string, passwordConfirmation string) (d
|
|||
tx := a.DB.Begin()
|
||||
|
||||
var count int64
|
||||
if err := tx.Model(database.Account{}).Where("email = ?", email).Count(&count).Error; err != nil {
|
||||
if err := tx.Model(&database.User{}).Where("email = ?", email).Count(&count).Error; err != nil {
|
||||
return database.User{}, pkgErrors.Wrap(err, "counting user")
|
||||
}
|
||||
if count > 0 {
|
||||
|
|
@ -85,21 +85,14 @@ func (a *App) CreateUser(email, password string, passwordConfirmation string) (d
|
|||
}
|
||||
|
||||
user := database.User{
|
||||
UUID: uuid,
|
||||
UUID: uuid,
|
||||
Email: database.ToNullString(email),
|
||||
Password: database.ToNullString(string(hashedPassword)),
|
||||
}
|
||||
if err = tx.Save(&user).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return database.User{}, pkgErrors.Wrap(err, "saving user")
|
||||
}
|
||||
account := database.Account{
|
||||
Email: database.ToNullString(email),
|
||||
Password: database.ToNullString(string(hashedPassword)),
|
||||
UserID: user.ID,
|
||||
}
|
||||
if err = tx.Save(&account).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return database.User{}, pkgErrors.Wrap(err, "saving account")
|
||||
}
|
||||
|
||||
if err := a.TouchLastLoginAt(user, tx); err != nil {
|
||||
tx.Rollback()
|
||||
|
|
@ -111,42 +104,36 @@ func (a *App) CreateUser(email, password string, passwordConfirmation string) (d
|
|||
return user, nil
|
||||
}
|
||||
|
||||
// GetAccountByEmail finds an account by email
|
||||
func (a *App) GetAccountByEmail(email string) (*database.Account, error) {
|
||||
var account database.Account
|
||||
err := a.DB.Where("email = ?", email).First(&account).Error
|
||||
// GetUserByEmail finds a user by email
|
||||
func (a *App) GetUserByEmail(email string) (*database.User, error) {
|
||||
var user database.User
|
||||
err := a.DB.Where("email = ?", email).First(&user).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrNotFound
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &account, nil
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// Authenticate authenticates a user
|
||||
func (a *App) Authenticate(email, password string) (*database.User, error) {
|
||||
account, err := a.GetAccountByEmail(email)
|
||||
user, err := a.GetUserByEmail(email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte(password))
|
||||
err = bcrypt.CompareHashAndPassword([]byte(user.Password.String), []byte(password))
|
||||
if err != nil {
|
||||
return nil, ErrLoginInvalid
|
||||
}
|
||||
|
||||
var user database.User
|
||||
err = a.DB.Where("id = ?", account.UserID).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, pkgErrors.Wrap(err, "finding user")
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// UpdateAccountPassword updates an account's password with validation
|
||||
func UpdateAccountPassword(db *gorm.DB, account *database.Account, newPassword string) error {
|
||||
// UpdateUserPassword updates a user's password with validation
|
||||
func UpdateUserPassword(db *gorm.DB, user *database.User, newPassword string) error {
|
||||
// Validate password
|
||||
if err := validatePassword(newPassword); err != nil {
|
||||
return err
|
||||
|
|
@ -159,25 +146,25 @@ func UpdateAccountPassword(db *gorm.DB, account *database.Account, newPassword s
|
|||
}
|
||||
|
||||
// Update the password
|
||||
if err := db.Model(&account).Update("password", string(hashedPassword)).Error; err != nil {
|
||||
if err := db.Model(&user).Update("password", string(hashedPassword)).Error; err != nil {
|
||||
return pkgErrors.Wrap(err, "updating password")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveUser removes a user and their account from the system
|
||||
// RemoveUser removes a user from the system
|
||||
// Returns an error if the user has any notes or books
|
||||
func (a *App) RemoveUser(email string) error {
|
||||
// Find the account and user
|
||||
account, err := a.GetAccountByEmail(email)
|
||||
// Find the user
|
||||
user, err := a.GetUserByEmail(email)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if user has any notes
|
||||
var noteCount int64
|
||||
if err := a.DB.Model(&database.Note{}).Where("user_id = ? AND deleted = ?", account.UserID, false).Count(¬eCount).Error; err != nil {
|
||||
if err := a.DB.Model(&database.Note{}).Where("user_id = ? AND deleted = ?", user.ID, false).Count(¬eCount).Error; err != nil {
|
||||
return pkgErrors.Wrap(err, "counting notes")
|
||||
}
|
||||
if noteCount > 0 {
|
||||
|
|
@ -186,34 +173,18 @@ func (a *App) RemoveUser(email string) error {
|
|||
|
||||
// Check if user has any books
|
||||
var bookCount int64
|
||||
if err := a.DB.Model(&database.Book{}).Where("user_id = ? AND deleted = ?", account.UserID, false).Count(&bookCount).Error; err != nil {
|
||||
if err := a.DB.Model(&database.Book{}).Where("user_id = ? AND deleted = ?", user.ID, false).Count(&bookCount).Error; err != nil {
|
||||
return pkgErrors.Wrap(err, "counting books")
|
||||
}
|
||||
if bookCount > 0 {
|
||||
return ErrUserHasExistingResources
|
||||
}
|
||||
|
||||
// Delete account and user in a transaction
|
||||
tx := a.DB.Begin()
|
||||
|
||||
if err := tx.Delete(&account).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return pkgErrors.Wrap(err, "deleting account")
|
||||
}
|
||||
|
||||
var user database.User
|
||||
if err := tx.Where("id = ?", account.UserID).First(&user).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return pkgErrors.Wrap(err, "finding user")
|
||||
}
|
||||
|
||||
if err := tx.Delete(&user).Error; err != nil {
|
||||
tx.Rollback()
|
||||
// Delete user
|
||||
if err := a.DB.Delete(&user).Error; err != nil {
|
||||
return pkgErrors.Wrap(err, "deleting user")
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -82,21 +82,20 @@ func TestCreateUser_ProValue(t *testing.T) {
|
|||
|
||||
}
|
||||
|
||||
func TestGetAccountByEmail(t *testing.T) {
|
||||
func TestGetUserByEmail(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@example.com", "password123")
|
||||
user := testutils.SetupUserData(db, "alice@example.com", "password123")
|
||||
|
||||
a := NewTest()
|
||||
a.DB = db
|
||||
|
||||
account, err := a.GetAccountByEmail("alice@example.com")
|
||||
foundUser, err := a.GetUserByEmail("alice@example.com")
|
||||
|
||||
assert.Equal(t, err, nil, "should not error")
|
||||
assert.Equal(t, account.Email.String, "alice@example.com", "email mismatch")
|
||||
assert.Equal(t, account.UserID, user.ID, "user ID mismatch")
|
||||
assert.Equal(t, foundUser.Email.String, "alice@example.com", "email mismatch")
|
||||
assert.Equal(t, foundUser.ID, user.ID, "user ID mismatch")
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
|
|
@ -105,10 +104,10 @@ func TestGetAccountByEmail(t *testing.T) {
|
|||
a := NewTest()
|
||||
a.DB = db
|
||||
|
||||
account, err := a.GetAccountByEmail("nonexistent@example.com")
|
||||
user, err := a.GetUserByEmail("nonexistent@example.com")
|
||||
|
||||
assert.Equal(t, err, ErrNotFound, "should return ErrNotFound")
|
||||
assert.Equal(t, account, (*database.Account)(nil), "account should be nil")
|
||||
assert.Equal(t, user, (*database.User)(nil), "user should be nil")
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -124,25 +123,21 @@ func TestCreateUser(t *testing.T) {
|
|||
|
||||
var userCount int64
|
||||
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
|
||||
assert.Equal(t, userCount, int64(1), "book count mismatch")
|
||||
assert.Equal(t, userCount, int64(1), "user count mismatch")
|
||||
|
||||
var accountCount int64
|
||||
var accountRecord database.Account
|
||||
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
|
||||
testutils.MustExec(t, db.First(&accountRecord), "finding account")
|
||||
var userRecord database.User
|
||||
testutils.MustExec(t, db.First(&userRecord), "finding user")
|
||||
|
||||
assert.Equal(t, accountCount, int64(1), "account count mismatch")
|
||||
assert.Equal(t, accountRecord.Email.String, "alice@example.com", "account email mismatch")
|
||||
assert.Equal(t, userRecord.Email.String, "alice@example.com", "user email mismatch")
|
||||
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(accountRecord.Password.String), []byte("pass1234"))
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(userRecord.Password.String), []byte("pass1234"))
|
||||
assert.Equal(t, passwordErr, nil, "Password mismatch")
|
||||
})
|
||||
|
||||
t.Run("duplicate email", func(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
aliceUser := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, aliceUser, "alice@example.com", "somepassword")
|
||||
testutils.SetupUserData(db, "alice@example.com", "somepassword")
|
||||
|
||||
a := NewTest()
|
||||
a.DB = db
|
||||
|
|
@ -150,116 +145,109 @@ func TestCreateUser(t *testing.T) {
|
|||
|
||||
assert.Equal(t, err, ErrDuplicateEmail, "error mismatch")
|
||||
|
||||
var userCount, accountCount int64
|
||||
var userCount int64
|
||||
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
|
||||
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
|
||||
|
||||
assert.Equal(t, userCount, int64(1), "user count mismatch")
|
||||
assert.Equal(t, accountCount, int64(1), "account count mismatch")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateAccountPassword(t *testing.T) {
|
||||
func TestUpdateUserPassword(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123")
|
||||
user := testutils.SetupUserData(db, "alice@example.com", "oldpassword123")
|
||||
|
||||
err := UpdateAccountPassword(db, &account, "newpassword123")
|
||||
err := UpdateUserPassword(db, &user, "newpassword123")
|
||||
|
||||
assert.Equal(t, err, nil, "should not error")
|
||||
|
||||
// Verify password was updated in database
|
||||
var updatedAccount database.Account
|
||||
testutils.MustExec(t, db.Where("id = ?", account.ID).First(&updatedAccount), "finding updated account")
|
||||
var updatedUser database.User
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&updatedUser), "finding updated user")
|
||||
|
||||
// Verify new password works
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("newpassword123"))
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(updatedUser.Password.String), []byte("newpassword123"))
|
||||
assert.Equal(t, passwordErr, nil, "New password should match")
|
||||
|
||||
// Verify old password no longer works
|
||||
oldPasswordErr := bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("oldpassword123"))
|
||||
oldPasswordErr := bcrypt.CompareHashAndPassword([]byte(updatedUser.Password.String), []byte("oldpassword123"))
|
||||
assert.NotEqual(t, oldPasswordErr, nil, "Old password should not match")
|
||||
})
|
||||
|
||||
t.Run("password too short", func(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123")
|
||||
user := testutils.SetupUserData(db, "alice@example.com", "oldpassword123")
|
||||
|
||||
err := UpdateAccountPassword(db, &account, "short")
|
||||
err := UpdateUserPassword(db, &user, "short")
|
||||
|
||||
assert.Equal(t, err, ErrPasswordTooShort, "should return ErrPasswordTooShort")
|
||||
|
||||
// Verify password was NOT updated in database
|
||||
var unchangedAccount database.Account
|
||||
testutils.MustExec(t, db.Where("id = ?", account.ID).First(&unchangedAccount), "finding unchanged account")
|
||||
var unchangedUser database.User
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&unchangedUser), "finding unchanged user")
|
||||
|
||||
// Verify old password still works
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedAccount.Password.String), []byte("oldpassword123"))
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedUser.Password.String), []byte("oldpassword123"))
|
||||
assert.Equal(t, passwordErr, nil, "Old password should still match")
|
||||
})
|
||||
|
||||
t.Run("empty password", func(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123")
|
||||
user := testutils.SetupUserData(db, "alice@example.com", "oldpassword123")
|
||||
|
||||
err := UpdateAccountPassword(db, &account, "")
|
||||
err := UpdateUserPassword(db, &user, "")
|
||||
|
||||
assert.Equal(t, err, ErrPasswordTooShort, "should return ErrPasswordTooShort")
|
||||
|
||||
// Verify password was NOT updated in database
|
||||
var unchangedAccount database.Account
|
||||
testutils.MustExec(t, db.Where("id = ?", account.ID).First(&unchangedAccount), "finding unchanged account")
|
||||
var unchangedUser database.User
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&unchangedUser), "finding unchanged user")
|
||||
|
||||
// Verify old password still works
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedAccount.Password.String), []byte("oldpassword123"))
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedUser.Password.String), []byte("oldpassword123"))
|
||||
assert.Equal(t, passwordErr, nil, "Old password should still match")
|
||||
})
|
||||
|
||||
t.Run("transaction rollback", func(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123")
|
||||
user := testutils.SetupUserData(db, "alice@example.com", "oldpassword123")
|
||||
|
||||
// Start a transaction and rollback to verify UpdateAccountPassword respects transactions
|
||||
// Start a transaction and rollback to verify UpdateUserPassword respects transactions
|
||||
tx := db.Begin()
|
||||
err := UpdateAccountPassword(tx, &account, "newpassword123")
|
||||
err := UpdateUserPassword(tx, &user, "newpassword123")
|
||||
assert.Equal(t, err, nil, "should not error")
|
||||
tx.Rollback()
|
||||
|
||||
// Verify password was NOT updated after rollback
|
||||
var unchangedAccount database.Account
|
||||
testutils.MustExec(t, db.Where("id = ?", account.ID).First(&unchangedAccount), "finding unchanged account")
|
||||
var unchangedUser database.User
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&unchangedUser), "finding unchanged user")
|
||||
|
||||
// Verify old password still works
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedAccount.Password.String), []byte("oldpassword123"))
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedUser.Password.String), []byte("oldpassword123"))
|
||||
assert.Equal(t, passwordErr, nil, "Old password should still match after rollback")
|
||||
})
|
||||
|
||||
t.Run("transaction commit", func(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123")
|
||||
user := testutils.SetupUserData(db, "alice@example.com", "oldpassword123")
|
||||
|
||||
// Start a transaction and commit to verify UpdateAccountPassword respects transactions
|
||||
// Start a transaction and commit to verify UpdateUserPassword respects transactions
|
||||
tx := db.Begin()
|
||||
err := UpdateAccountPassword(tx, &account, "newpassword123")
|
||||
err := UpdateUserPassword(tx, &user, "newpassword123")
|
||||
assert.Equal(t, err, nil, "should not error")
|
||||
tx.Commit()
|
||||
|
||||
// Verify password was updated after commit
|
||||
var updatedAccount database.Account
|
||||
testutils.MustExec(t, db.Where("id = ?", account.ID).First(&updatedAccount), "finding updated account")
|
||||
var updatedUser database.User
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&updatedUser), "finding updated user")
|
||||
|
||||
// Verify new password works
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("newpassword123"))
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(updatedUser.Password.String), []byte("newpassword123"))
|
||||
assert.Equal(t, passwordErr, nil, "New password should match after commit")
|
||||
})
|
||||
}
|
||||
|
|
@ -268,8 +256,7 @@ func TestRemoveUser(t *testing.T) {
|
|||
t.Run("success", func(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@example.com", "password123")
|
||||
testutils.SetupUserData(db, "alice@example.com", "password123")
|
||||
|
||||
a := NewTest()
|
||||
a.DB = db
|
||||
|
|
@ -282,11 +269,6 @@ func TestRemoveUser(t *testing.T) {
|
|||
var userCount int64
|
||||
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting users")
|
||||
assert.Equal(t, userCount, int64(0), "user should be deleted")
|
||||
|
||||
// Verify account was deleted
|
||||
var accountCount int64
|
||||
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting accounts")
|
||||
assert.Equal(t, accountCount, int64(0), "account should be deleted")
|
||||
})
|
||||
|
||||
t.Run("user not found", func(t *testing.T) {
|
||||
|
|
@ -303,8 +285,7 @@ func TestRemoveUser(t *testing.T) {
|
|||
t.Run("user has notes", func(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@example.com", "password123")
|
||||
user := testutils.SetupUserData(db, "alice@example.com", "password123")
|
||||
|
||||
book := database.Book{UserID: user.ID, Label: "testbook", Deleted: false}
|
||||
testutils.MustExec(t, db.Save(&book), "creating book")
|
||||
|
|
@ -324,17 +305,12 @@ func TestRemoveUser(t *testing.T) {
|
|||
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting users")
|
||||
assert.Equal(t, userCount, int64(1), "user should not be deleted")
|
||||
|
||||
// Verify account was NOT deleted
|
||||
var accountCount int64
|
||||
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting accounts")
|
||||
assert.Equal(t, accountCount, int64(1), "account should not be deleted")
|
||||
})
|
||||
|
||||
t.Run("user has books", func(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@example.com", "password123")
|
||||
user := testutils.SetupUserData(db, "alice@example.com", "password123")
|
||||
|
||||
book := database.Book{UserID: user.ID, Label: "testbook", Deleted: false}
|
||||
testutils.MustExec(t, db.Save(&book), "creating book")
|
||||
|
|
@ -351,17 +327,12 @@ func TestRemoveUser(t *testing.T) {
|
|||
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting users")
|
||||
assert.Equal(t, userCount, int64(1), "user should not be deleted")
|
||||
|
||||
// Verify account was NOT deleted
|
||||
var accountCount int64
|
||||
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting accounts")
|
||||
assert.Equal(t, accountCount, int64(1), "account should not be deleted")
|
||||
})
|
||||
|
||||
t.Run("user has deleted notes and books", func(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@example.com", "password123")
|
||||
user := testutils.SetupUserData(db, "alice@example.com", "password123")
|
||||
|
||||
book := database.Book{UserID: user.ID, Label: "testbook", Deleted: false}
|
||||
testutils.MustExec(t, db.Save(&book), "creating book")
|
||||
|
|
@ -385,9 +356,5 @@ func TestRemoveUser(t *testing.T) {
|
|||
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting users")
|
||||
assert.Equal(t, userCount, int64(0), "user should be deleted")
|
||||
|
||||
// Verify account was deleted
|
||||
var accountCount int64
|
||||
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting accounts")
|
||||
assert.Equal(t, accountCount, int64(0), "account should be deleted")
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -81,12 +81,12 @@ func userRemoveCmd(args []string, stdin io.Reader) {
|
|||
defer cleanup()
|
||||
|
||||
// Check if user exists first
|
||||
_, err := a.GetAccountByEmail(*email)
|
||||
_, err := a.GetUserByEmail(*email)
|
||||
if err != nil {
|
||||
if errors.Is(err, app.ErrNotFound) {
|
||||
fmt.Printf("Error: user with email %s not found\n", *email)
|
||||
} else {
|
||||
log.ErrorWrap(err, "finding account")
|
||||
log.ErrorWrap(err, "finding user")
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
|
|
@ -133,19 +133,19 @@ func userResetPasswordCmd(args []string) {
|
|||
a, cleanup := setupAppWithDB(fs, *dbPath)
|
||||
defer cleanup()
|
||||
|
||||
// Find the account
|
||||
account, err := a.GetAccountByEmail(*email)
|
||||
// Find the user
|
||||
user, err := a.GetUserByEmail(*email)
|
||||
if err != nil {
|
||||
if errors.Is(err, app.ErrNotFound) {
|
||||
fmt.Printf("Error: user with email %s not found\n", *email)
|
||||
} else {
|
||||
log.ErrorWrap(err, "finding account")
|
||||
log.ErrorWrap(err, "finding user")
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Update the password
|
||||
if err := app.UpdateAccountPassword(a.DB, account, *password); err != nil {
|
||||
if err := app.UpdateUserPassword(a.DB, user, *password); err != nil {
|
||||
log.ErrorWrap(err, "updating password")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -45,9 +45,9 @@ func TestUserCreateCmd(t *testing.T) {
|
|||
testutils.MustExec(t, db.Model(&database.User{}).Count(&count), "counting users")
|
||||
assert.Equal(t, count, int64(1), "should have 1 user")
|
||||
|
||||
var account database.Account
|
||||
testutils.MustExec(t, db.Where("email = ?", "test@example.com").First(&account), "finding account")
|
||||
assert.Equal(t, account.Email.String, "test@example.com", "email mismatch")
|
||||
var user database.User
|
||||
testutils.MustExec(t, db.Where("email = ?", "test@example.com").First(&user), "finding user")
|
||||
assert.Equal(t, user.Email.String, "test@example.com", "email mismatch")
|
||||
}
|
||||
|
||||
func TestUserRemoveCmd(t *testing.T) {
|
||||
|
|
@ -55,8 +55,7 @@ func TestUserRemoveCmd(t *testing.T) {
|
|||
|
||||
// Create a user first
|
||||
db := testutils.InitDB(tmpDB)
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "test@example.com", "password123")
|
||||
testutils.SetupUserData(db, "test@example.com", "password123")
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.Close()
|
||||
|
||||
|
|
@ -81,9 +80,8 @@ func TestUserResetPasswordCmd(t *testing.T) {
|
|||
|
||||
// Create a user first
|
||||
db := testutils.InitDB(tmpDB)
|
||||
user := testutils.SetupUserData(db)
|
||||
account := testutils.SetupAccountData(db, user, "test@example.com", "oldpassword123")
|
||||
oldPasswordHash := account.Password.String
|
||||
user := testutils.SetupUserData(db, "test@example.com", "oldpassword123")
|
||||
oldPasswordHash := user.Password.String
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.Close()
|
||||
|
||||
|
|
@ -97,18 +95,18 @@ func TestUserResetPasswordCmd(t *testing.T) {
|
|||
sqlDB2.Close()
|
||||
}()
|
||||
|
||||
var updatedAccount database.Account
|
||||
testutils.MustExec(t, db2.Where("email = ?", "test@example.com").First(&updatedAccount), "finding account")
|
||||
var updatedUser database.User
|
||||
testutils.MustExec(t, db2.Where("email = ?", "test@example.com").First(&updatedUser), "finding user")
|
||||
|
||||
// Verify password hash changed
|
||||
assert.Equal(t, updatedAccount.Password.String != oldPasswordHash, true, "password hash should be different")
|
||||
assert.Equal(t, len(updatedAccount.Password.String) > 0, true, "password should be set")
|
||||
assert.Equal(t, updatedUser.Password.String != oldPasswordHash, true, "password hash should be different")
|
||||
assert.Equal(t, len(updatedUser.Password.String) > 0, true, "password should be set")
|
||||
|
||||
// Verify new password works
|
||||
err := bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("newpassword123"))
|
||||
err := bcrypt.CompareHashAndPassword([]byte(updatedUser.Password.String), []byte("newpassword123"))
|
||||
assert.Equal(t, err, nil, "new password should match")
|
||||
|
||||
// Verify old password doesn't work
|
||||
err = bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("oldpassword123"))
|
||||
err = bcrypt.CompareHashAndPassword([]byte(updatedUser.Password.String), []byte("oldpassword123"))
|
||||
assert.Equal(t, err != nil, true, "old password should not match")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,9 +25,8 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
userKey privateKey = "user"
|
||||
accountKey privateKey = "account"
|
||||
tokenKey privateKey = "token"
|
||||
userKey privateKey = "user"
|
||||
tokenKey privateKey = "token"
|
||||
)
|
||||
|
||||
type privateKey string
|
||||
|
|
@ -37,11 +36,6 @@ func WithUser(ctx context.Context, user *database.User) context.Context {
|
|||
return context.WithValue(ctx, userKey, user)
|
||||
}
|
||||
|
||||
// WithAccount creates a new context with the given account
|
||||
func WithAccount(ctx context.Context, account *database.Account) context.Context {
|
||||
return context.WithValue(ctx, accountKey, account)
|
||||
}
|
||||
|
||||
// WithToken creates a new context with the given user
|
||||
func WithToken(ctx context.Context, tok *database.Token) context.Context {
|
||||
return context.WithValue(ctx, tokenKey, tok)
|
||||
|
|
@ -59,17 +53,6 @@ func User(ctx context.Context) *database.User {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Account retrieves an account from the given context.
|
||||
func Account(ctx context.Context) *database.Account {
|
||||
if temp := ctx.Value(accountKey); temp != nil {
|
||||
if account, ok := temp.(*database.Account); ok {
|
||||
return account
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Token retrieves a token from the given context.
|
||||
func Token(ctx context.Context) *database.Token {
|
||||
if temp := ctx.Value(tokenKey); temp != nil {
|
||||
|
|
|
|||
|
|
@ -50,10 +50,8 @@ func TestGetBooks(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
|
||||
anotherUser := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
|
||||
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
|
||||
anotherUser := testutils.SetupUserData(db, "bob@test.com", "pass1234")
|
||||
|
||||
b1 := database.Book{
|
||||
UUID: testutils.MustUUID(t),
|
||||
|
|
@ -143,10 +141,8 @@ func TestGetBooksByName(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
|
||||
anotherUser := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
|
||||
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
|
||||
anotherUser := testutils.SetupUserData(db, "bob@test.com", "pass1234")
|
||||
|
||||
b1 := database.Book{
|
||||
UUID: testutils.MustUUID(t),
|
||||
|
|
@ -212,10 +208,8 @@ func TestGetBook(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
|
||||
anotherUser := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
|
||||
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
|
||||
anotherUser := testutils.SetupUserData(db, "bob@test.com", "pass1234")
|
||||
|
||||
b1 := database.Book{
|
||||
UUID: testutils.MustUUID(t),
|
||||
|
|
@ -276,10 +270,8 @@ func TestGetBookNonOwner(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
|
||||
nonOwner := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, nonOwner, "bob@test.com", "pass1234")
|
||||
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
|
||||
nonOwner := testutils.SetupUserData(db, "bob@test.com", "pass1234")
|
||||
|
||||
b1 := database.Book{
|
||||
UUID: testutils.MustUUID(t),
|
||||
|
|
@ -314,8 +306,7 @@ func TestCreateBook(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
|
||||
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
|
||||
|
||||
req := testutils.MakeReq(server.URL, "POST", "/api/v3/books", `{"name": "js"}`)
|
||||
|
|
@ -375,8 +366,7 @@ func TestCreateBook(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
|
||||
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
|
||||
|
||||
b1 := database.Book{
|
||||
|
|
@ -465,8 +455,7 @@ func TestUpdateBook(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
|
||||
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
|
||||
|
||||
b1 := database.Book{
|
||||
|
|
@ -550,11 +539,9 @@ func TestDeleteBook(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
|
||||
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", 58), "preparing user max_usn")
|
||||
anotherUser := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
|
||||
anotherUser := testutils.SetupUserData(db, "bob@test.com", "pass1234")
|
||||
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 109), "preparing another user max_usn")
|
||||
|
||||
b1 := database.Book{
|
||||
|
|
|
|||
|
|
@ -63,10 +63,8 @@ func TestGetNotes(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
|
||||
anotherUser := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
|
||||
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
|
||||
anotherUser := testutils.SetupUserData(db, "bob@test.com", "pass1234")
|
||||
|
||||
b1 := database.Book{
|
||||
UUID: testutils.MustUUID(t),
|
||||
|
|
@ -187,10 +185,8 @@ func TestGetNote(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "user@test.com", "pass1234")
|
||||
anotherUser := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, anotherUser, "another@test.com", "pass1234")
|
||||
user := testutils.SetupUserData(db, "user@test.com", "pass1234")
|
||||
anotherUser := testutils.SetupUserData(db, "another@test.com", "pass1234")
|
||||
|
||||
b1 := database.Book{
|
||||
UUID: testutils.MustUUID(t),
|
||||
|
|
@ -318,8 +314,7 @@ func TestCreateNote(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
|
||||
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
|
||||
|
||||
b1 := database.Book{
|
||||
|
|
@ -400,8 +395,7 @@ func TestDeleteNote(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
|
||||
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", 981), "preparing user max_usn")
|
||||
|
||||
b1 := database.Book{
|
||||
|
|
@ -559,8 +553,7 @@ func TestUpdateNote(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
|
||||
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
|
||||
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
|
||||
|
||||
|
|
|
|||
|
|
@ -307,23 +307,23 @@ func (u *Users) CreateResetToken(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
var account database.Account
|
||||
err := u.app.DB.Where("email = ?", form.Email).First(&account).Error
|
||||
var user database.User
|
||||
err := u.app.DB.Where("email = ?", form.Email).First(&user).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
handleHTMLError(w, r, err, "finding account", u.PasswordResetView, vd)
|
||||
handleHTMLError(w, r, err, "finding user", u.PasswordResetView, vd)
|
||||
return
|
||||
}
|
||||
|
||||
resetToken, err := token.Create(u.app.DB, account.UserID, database.TokenTypeResetPassword)
|
||||
resetToken, err := token.Create(u.app.DB, user.ID, database.TokenTypeResetPassword)
|
||||
if err != nil {
|
||||
handleHTMLError(w, r, err, "generating token", u.PasswordResetView, vd)
|
||||
return
|
||||
}
|
||||
|
||||
if err := u.app.SendPasswordResetEmail(account.Email.String, resetToken.Value); err != nil {
|
||||
if err := u.app.SendPasswordResetEmail(user.Email.String, resetToken.Value); err != nil {
|
||||
handleHTMLError(w, r, err, "sending password reset email", u.PasswordResetView, vd)
|
||||
return
|
||||
}
|
||||
|
|
@ -396,8 +396,8 @@ func (u *Users) PasswordReset(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
var account database.Account
|
||||
if err := u.app.DB.Where("user_id = ?", token.UserID).First(&account).Error; err != nil {
|
||||
var user database.User
|
||||
if err := u.app.DB.Where("id = ?", token.UserID).First(&user).Error; err != nil {
|
||||
handleHTMLError(w, r, err, "finding user", u.PasswordResetConfirmView, vd)
|
||||
return
|
||||
}
|
||||
|
|
@ -405,7 +405,7 @@ func (u *Users) PasswordReset(w http.ResponseWriter, r *http.Request) {
|
|||
tx := u.app.DB.Begin()
|
||||
|
||||
// Update the password
|
||||
if err := app.UpdateAccountPassword(tx, &account, params.Password); err != nil {
|
||||
if err := app.UpdateUserPassword(tx, &user, params.Password); err != nil {
|
||||
tx.Rollback()
|
||||
handleHTMLError(w, r, err, "updating password", u.PasswordResetConfirmView, vd)
|
||||
return
|
||||
|
|
@ -417,7 +417,7 @@ func (u *Users) PasswordReset(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if err := u.app.DeleteUserSessions(tx, account.UserID); err != nil {
|
||||
if err := u.app.DeleteUserSessions(tx, user.ID); err != nil {
|
||||
tx.Rollback()
|
||||
handleHTMLError(w, r, err, "deleting user sessions", u.PasswordResetConfirmView, vd)
|
||||
return
|
||||
|
|
@ -425,19 +425,13 @@ func (u *Users) PasswordReset(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
tx.Commit()
|
||||
|
||||
var user database.User
|
||||
if err := u.app.DB.Where("id = ?", account.UserID).First(&user).Error; err != nil {
|
||||
handleHTMLError(w, r, err, "finding user", u.PasswordResetConfirmView, vd)
|
||||
return
|
||||
}
|
||||
|
||||
alert := views.Alert{
|
||||
Level: views.AlertLvlSuccess,
|
||||
Message: "Password reset successful",
|
||||
}
|
||||
views.RedirectAlert(w, r, "/login", http.StatusFound, alert)
|
||||
|
||||
if err := u.app.SendPasswordResetAlertEmail(account.Email.String); err != nil {
|
||||
if err := u.app.SendPasswordResetAlertEmail(user.Email.String); err != nil {
|
||||
log.ErrorWrap(err, "sending password reset email")
|
||||
}
|
||||
}
|
||||
|
|
@ -493,14 +487,8 @@ func (u *Users) PasswordUpdate(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
var account database.Account
|
||||
if err := u.app.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
|
||||
handleHTMLError(w, r, err, "getting account", u.SettingView, vd)
|
||||
return
|
||||
}
|
||||
|
||||
password := []byte(form.OldPassword)
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(account.Password.String), password); err != nil {
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password.String), password); err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"user_id": user.ID,
|
||||
}).Warn("invalid password update attempt")
|
||||
|
|
@ -508,7 +496,7 @@ func (u *Users) PasswordUpdate(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if err := app.UpdateAccountPassword(u.app.DB, &account, form.NewPassword); err != nil {
|
||||
if err := app.UpdateUserPassword(u.app.DB, user, form.NewPassword); err != nil {
|
||||
handleHTMLError(w, r, err, "updating password", u.SettingView, vd)
|
||||
return
|
||||
}
|
||||
|
|
@ -534,12 +522,6 @@ func (u *Users) ProfileUpdate(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
var account database.Account
|
||||
if err := u.app.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
|
||||
handleHTMLError(w, r, err, "getting account", u.SettingView, vd)
|
||||
return
|
||||
}
|
||||
|
||||
var form updateProfileForm
|
||||
if err := parseRequestData(r, &form); err != nil {
|
||||
handleHTMLError(w, r, err, "parsing payload", u.SettingView, vd)
|
||||
|
|
@ -547,7 +529,7 @@ func (u *Users) ProfileUpdate(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
password := []byte(form.Password)
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(account.Password.String), password); err != nil {
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password.String), password); err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"user_id": user.ID,
|
||||
}).Warn("invalid email update attempt")
|
||||
|
|
@ -561,23 +543,13 @@ func (u *Users) ProfileUpdate(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
tx := u.app.DB.Begin()
|
||||
if err := tx.Save(&user).Error; err != nil {
|
||||
tx.Rollback()
|
||||
user.Email.String = form.Email
|
||||
|
||||
if err := u.app.DB.Save(&user).Error; err != nil {
|
||||
handleHTMLError(w, r, err, "saving user", u.SettingView, vd)
|
||||
return
|
||||
}
|
||||
|
||||
account.Email.String = form.Email
|
||||
|
||||
if err := tx.Save(&account).Error; err != nil {
|
||||
tx.Rollback()
|
||||
handleHTMLError(w, r, err, "saving account", u.SettingView, vd)
|
||||
return
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
alert := views.Alert{
|
||||
Level: views.AlertLvlSuccess,
|
||||
Message: "Email change successful",
|
||||
|
|
|
|||
|
|
@ -98,15 +98,14 @@ func TestJoin(t *testing.T) {
|
|||
// Test
|
||||
assert.StatusCodeEquals(t, res, http.StatusFound, "")
|
||||
|
||||
var account database.Account
|
||||
testutils.MustExec(t, db.Where("email = ?", tc.email).First(&account), "finding account")
|
||||
assert.Equal(t, account.Email.String, tc.email, "Email mismatch")
|
||||
assert.NotEqual(t, account.UserID, 0, "UserID mismatch")
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte(tc.password))
|
||||
var user database.User
|
||||
testutils.MustExec(t, db.Where("email = ?", tc.email).First(&user), "finding account")
|
||||
assert.Equal(t, user.Email.String, tc.email, "Email mismatch")
|
||||
assert.NotEqual(t, user.ID, 0, "UserID mismatch")
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(user.Password.String), []byte(tc.password))
|
||||
assert.Equal(t, passwordErr, nil, "Password mismatch")
|
||||
|
||||
var user database.User
|
||||
testutils.MustExec(t, db.Where("id = ?", account.UserID).First(&user), "finding user")
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&user), "finding user")
|
||||
assert.Equal(t, user.MaxUSN, 0, "MaxUSN mismatch")
|
||||
|
||||
// welcome email
|
||||
|
|
@ -140,11 +139,9 @@ func TestJoinError(t *testing.T) {
|
|||
// Test
|
||||
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch")
|
||||
|
||||
var accountCount, userCount int64
|
||||
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
|
||||
var userCount int64
|
||||
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
|
||||
|
||||
assert.Equal(t, accountCount, int64(0), "accountCount mismatch")
|
||||
assert.Equal(t, userCount, int64(0), "userCount mismatch")
|
||||
})
|
||||
|
||||
|
|
@ -168,11 +165,9 @@ func TestJoinError(t *testing.T) {
|
|||
// Test
|
||||
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch")
|
||||
|
||||
var accountCount, userCount int64
|
||||
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
|
||||
var userCount int64
|
||||
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
|
||||
|
||||
assert.Equal(t, accountCount, int64(0), "accountCount mismatch")
|
||||
assert.Equal(t, userCount, int64(0), "userCount mismatch")
|
||||
})
|
||||
|
||||
|
|
@ -198,11 +193,9 @@ func TestJoinError(t *testing.T) {
|
|||
// Test
|
||||
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch")
|
||||
|
||||
var accountCount, userCount int64
|
||||
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
|
||||
var userCount int64
|
||||
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
|
||||
|
||||
assert.Equal(t, accountCount, int64(0), "accountCount mismatch")
|
||||
assert.Equal(t, userCount, int64(0), "userCount mismatch")
|
||||
})
|
||||
}
|
||||
|
|
@ -217,8 +210,7 @@ func TestJoinDuplicateEmail(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
|
||||
u := testutils.SetupUserData(db, "alice@example.com", "somepassword")
|
||||
|
||||
dat := url.Values{}
|
||||
dat.Set("email", "alice@example.com")
|
||||
|
|
@ -232,15 +224,13 @@ func TestJoinDuplicateEmail(t *testing.T) {
|
|||
// Test
|
||||
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "status code mismatch")
|
||||
|
||||
var accountCount, userCount, verificationTokenCount int64
|
||||
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
|
||||
var userCount, verificationTokenCount int64
|
||||
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
|
||||
testutils.MustExec(t, db.Model(&database.Token{}).Count(&verificationTokenCount), "counting verification token")
|
||||
|
||||
var user database.User
|
||||
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user")
|
||||
|
||||
assert.Equal(t, accountCount, int64(1), "account count mismatch")
|
||||
assert.Equal(t, userCount, int64(1), "user count mismatch")
|
||||
assert.Equal(t, verificationTokenCount, int64(0), "verification_token should not have been created")
|
||||
assert.Equal(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch")
|
||||
|
|
@ -268,11 +258,9 @@ func TestJoinDisabled(t *testing.T) {
|
|||
// Test
|
||||
assert.StatusCodeEquals(t, res, http.StatusNotFound, "status code mismatch")
|
||||
|
||||
var accountCount, userCount int64
|
||||
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
|
||||
var userCount int64
|
||||
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
|
||||
|
||||
assert.Equal(t, accountCount, int64(0), "account count mismatch")
|
||||
assert.Equal(t, userCount, int64(0), "user count mismatch")
|
||||
}
|
||||
|
||||
|
|
@ -286,8 +274,7 @@ func TestLogin(t *testing.T) {
|
|||
a.DB = db
|
||||
server := MustNewServer(t, &a)
|
||||
|
||||
u := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, u, "alice@example.com", "pass1234")
|
||||
_ = testutils.SetupUserData(db, "alice@example.com", "pass1234")
|
||||
defer server.Close()
|
||||
|
||||
// Execute
|
||||
|
|
@ -346,8 +333,7 @@ func TestLogin(t *testing.T) {
|
|||
a.DB = db
|
||||
server := MustNewServer(t, &a)
|
||||
|
||||
u := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, u, "alice@example.com", "pass1234")
|
||||
_ = testutils.SetupUserData(db, "alice@example.com", "pass1234")
|
||||
defer server.Close()
|
||||
|
||||
var req *http.Request
|
||||
|
|
@ -386,8 +372,7 @@ func TestLogin(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, u, "alice@example.com", "pass1234")
|
||||
_ = testutils.SetupUserData(db, "alice@example.com", "pass1234")
|
||||
|
||||
var req *http.Request
|
||||
if target == testutils.EndpointWeb {
|
||||
|
|
@ -456,9 +441,8 @@ func TestLogout(t *testing.T) {
|
|||
a.DB = db
|
||||
server := MustNewServer(t, &a)
|
||||
|
||||
aliceUser := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, aliceUser, "alice@example.com", "pass1234")
|
||||
anotherUser := testutils.SetupUserData(db)
|
||||
aliceUser := testutils.SetupUserData(db, "alice@example.com", "pass1234")
|
||||
anotherUser := testutils.SetupUserData(db, "bob@example.com", "password123")
|
||||
|
||||
session1ExpiresAt := time.Now().Add(time.Hour * 24)
|
||||
session1 := database.Session{
|
||||
|
|
@ -570,8 +554,7 @@ func TestResetPassword(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData(db)
|
||||
acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword")
|
||||
u := testutils.SetupUserData(db, "alice@example.com", "oldpassword")
|
||||
tok := database.Token{
|
||||
UserID: u.ID,
|
||||
Value: "MivFxYiSMMA4An9dP24DNQ==",
|
||||
|
|
@ -593,7 +576,7 @@ func TestResetPassword(t *testing.T) {
|
|||
}
|
||||
testutils.MustExec(t, db.Save(&s2), "preparing user session 2")
|
||||
|
||||
anotherUser := testutils.SetupUserData(db)
|
||||
anotherUser := testutils.SetupUserData(db, "bob@example.com", "password123")
|
||||
testutils.MustExec(t, db.Save(&database.Session{
|
||||
Key: "some-session-key-3",
|
||||
UserID: anotherUser.ID,
|
||||
|
|
@ -613,12 +596,12 @@ func TestResetPassword(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusFound, "Status code mismatch")
|
||||
|
||||
var resetToken database.Token
|
||||
var account database.Account
|
||||
var user database.User
|
||||
testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
|
||||
testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding account")
|
||||
|
||||
assert.NotEqual(t, resetToken.UsedAt, nil, "reset_token UsedAt mismatch")
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte("newpassword"))
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(user.Password.String), []byte("newpassword"))
|
||||
assert.Equal(t, passwordErr, nil, "Password mismatch")
|
||||
|
||||
var s1Count, s2Count int64
|
||||
|
|
@ -646,8 +629,7 @@ func TestResetPassword(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData(db)
|
||||
acc := testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
|
||||
u := testutils.SetupUserData(db, "alice@example.com", "somepassword")
|
||||
tok := database.Token{
|
||||
UserID: u.ID,
|
||||
Value: "MivFxYiSMMA4An9dP24DNQ==",
|
||||
|
|
@ -668,12 +650,12 @@ func TestResetPassword(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismatch")
|
||||
|
||||
var resetToken database.Token
|
||||
var account database.Account
|
||||
var user database.User
|
||||
testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
|
||||
testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding account")
|
||||
|
||||
assert.Equal(t, acc.Password, account.Password, "password should not have been updated")
|
||||
assert.Equal(t, acc.Password, account.Password, "password should not have been updated")
|
||||
assert.Equal(t, u.Password, user.Password, "password should not have been updated")
|
||||
assert.Equal(t, u.Password, user.Password, "password should not have been updated")
|
||||
assert.Equal(t, resetToken.UsedAt, (*time.Time)(nil), "used_at should be nil")
|
||||
})
|
||||
|
||||
|
|
@ -687,8 +669,7 @@ func TestResetPassword(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData(db)
|
||||
acc := testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
|
||||
u := testutils.SetupUserData(db, "alice@example.com", "somepassword")
|
||||
tok := database.Token{
|
||||
UserID: u.ID,
|
||||
Value: "MivFxYiSMMA4An9dP24DNQ==",
|
||||
|
|
@ -710,10 +691,10 @@ func TestResetPassword(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusGone, "Status code mismatch")
|
||||
|
||||
var resetToken database.Token
|
||||
var account database.Account
|
||||
var user database.User
|
||||
testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
|
||||
testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "failed to find account")
|
||||
assert.Equal(t, acc.Password, account.Password, "password should not have been updated")
|
||||
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "failed to find account")
|
||||
assert.Equal(t, u.Password, user.Password, "password should not have been updated")
|
||||
assert.Equal(t, resetToken.UsedAt, (*time.Time)(nil), "used_at should be nil")
|
||||
})
|
||||
|
||||
|
|
@ -727,8 +708,7 @@ func TestResetPassword(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData(db)
|
||||
acc := testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
|
||||
u := testutils.SetupUserData(db, "alice@example.com", "somepassword")
|
||||
|
||||
usedAt := time.Now().Add(time.Hour * -11).UTC()
|
||||
tok := database.Token{
|
||||
|
|
@ -753,10 +733,10 @@ func TestResetPassword(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismatch")
|
||||
|
||||
var resetToken database.Token
|
||||
var account database.Account
|
||||
var user database.User
|
||||
testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
|
||||
testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "failed to find account")
|
||||
assert.Equal(t, acc.Password, account.Password, "password should not have been updated")
|
||||
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "failed to find account")
|
||||
assert.Equal(t, u.Password, user.Password, "password should not have been updated")
|
||||
|
||||
resetTokenUsedAtUTC := resetToken.UsedAt.UTC()
|
||||
if resetTokenUsedAtUTC.Year() != usedAt.Year() ||
|
||||
|
|
@ -782,8 +762,7 @@ func TestCreateResetToken(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
|
||||
u := testutils.SetupUserData(db, "alice@example.com", "somepassword")
|
||||
|
||||
// Execute
|
||||
dat := url.Values{}
|
||||
|
|
@ -816,8 +795,7 @@ func TestCreateResetToken(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
|
||||
_ = testutils.SetupUserData(db, "alice@example.com", "somepassword")
|
||||
|
||||
// Execute
|
||||
dat := url.Values{}
|
||||
|
|
@ -846,8 +824,7 @@ func TestUpdatePassword(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword")
|
||||
user := testutils.SetupUserData(db, "alice@example.com", "oldpassword")
|
||||
|
||||
// Execute
|
||||
dat := url.Values{}
|
||||
|
|
@ -861,10 +838,9 @@ func TestUpdatePassword(t *testing.T) {
|
|||
// Test
|
||||
assert.StatusCodeEquals(t, res, http.StatusFound, "Status code mismsatch")
|
||||
|
||||
var account database.Account
|
||||
testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&user), "finding account")
|
||||
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte("newpassword"))
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(user.Password.String), []byte("newpassword"))
|
||||
assert.Equal(t, passwordErr, nil, "Password mismatch")
|
||||
})
|
||||
|
||||
|
|
@ -877,8 +853,7 @@ func TestUpdatePassword(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData(db)
|
||||
acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword")
|
||||
u := testutils.SetupUserData(db, "alice@example.com", "oldpassword")
|
||||
|
||||
// Execute
|
||||
dat := url.Values{}
|
||||
|
|
@ -892,9 +867,9 @@ func TestUpdatePassword(t *testing.T) {
|
|||
// Test
|
||||
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "Status code mismsatch")
|
||||
|
||||
var account database.Account
|
||||
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
|
||||
assert.Equal(t, acc.Password.String, account.Password.String, "password should not have been updated")
|
||||
var user database.User
|
||||
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding account")
|
||||
assert.Equal(t, u.Password.String, user.Password.String, "password should not have been updated")
|
||||
})
|
||||
|
||||
t.Run("password too short", func(t *testing.T) {
|
||||
|
|
@ -907,8 +882,7 @@ func TestUpdatePassword(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData(db)
|
||||
acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword")
|
||||
u := testutils.SetupUserData(db, "alice@example.com", "oldpassword")
|
||||
|
||||
// Execute
|
||||
dat := url.Values{}
|
||||
|
|
@ -922,9 +896,9 @@ func TestUpdatePassword(t *testing.T) {
|
|||
// Test
|
||||
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismsatch")
|
||||
|
||||
var account database.Account
|
||||
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
|
||||
assert.Equal(t, acc.Password.String, account.Password.String, "password should not have been updated")
|
||||
var user database.User
|
||||
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding account")
|
||||
assert.Equal(t, u.Password.String, user.Password.String, "password should not have been updated")
|
||||
})
|
||||
|
||||
t.Run("password confirmation mismatch", func(t *testing.T) {
|
||||
|
|
@ -937,8 +911,7 @@ func TestUpdatePassword(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData(db)
|
||||
acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword")
|
||||
u := testutils.SetupUserData(db, "alice@example.com", "oldpassword")
|
||||
|
||||
// Execute
|
||||
dat := url.Values{}
|
||||
|
|
@ -952,9 +925,9 @@ func TestUpdatePassword(t *testing.T) {
|
|||
// Test
|
||||
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismsatch")
|
||||
|
||||
var account database.Account
|
||||
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
|
||||
assert.Equal(t, acc.Password.String, account.Password.String, "password should not have been updated")
|
||||
var user database.User
|
||||
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding account")
|
||||
assert.Equal(t, u.Password.String, user.Password.String, "password should not have been updated")
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -969,8 +942,7 @@ func TestUpdateEmail(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, u, "alice@example.com", "pass1234")
|
||||
u := testutils.SetupUserData(db, "alice@example.com", "pass1234")
|
||||
|
||||
// Execute
|
||||
dat := url.Values{}
|
||||
|
|
@ -984,11 +956,10 @@ func TestUpdateEmail(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusFound, "Status code mismatch")
|
||||
|
||||
var user database.User
|
||||
var account database.Account
|
||||
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user")
|
||||
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding account")
|
||||
|
||||
assert.Equal(t, account.Email.String, "alice-new@example.com", "email mismatch")
|
||||
assert.Equal(t, user.Email.String, "alice-new@example.com", "email mismatch")
|
||||
})
|
||||
|
||||
t.Run("password mismatch", func(t *testing.T) {
|
||||
|
|
@ -1001,8 +972,7 @@ func TestUpdateEmail(t *testing.T) {
|
|||
server := MustNewServer(t, &a)
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, u, "alice@example.com", "pass1234")
|
||||
u := testutils.SetupUserData(db, "alice@example.com", "pass1234")
|
||||
|
||||
// Execute
|
||||
dat := url.Values{}
|
||||
|
|
@ -1016,11 +986,9 @@ func TestUpdateEmail(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "Status code mismsatch")
|
||||
|
||||
var user database.User
|
||||
var account database.Account
|
||||
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user")
|
||||
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
|
||||
|
||||
assert.Equal(t, account.Email.String, "alice@example.com", "email mismatch")
|
||||
assert.Equal(t, user.Email.String, "alice@example.com", "email mismatch")
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -36,7 +36,6 @@ var (
|
|||
func InitSchema(db *gorm.DB) {
|
||||
if err := db.AutoMigrate(
|
||||
&User{},
|
||||
&Account{},
|
||||
&Book{},
|
||||
&Note{},
|
||||
&Token{},
|
||||
|
|
|
|||
|
|
@ -61,20 +61,13 @@ type Note struct {
|
|||
// User is a model for a user
|
||||
type User struct {
|
||||
Model
|
||||
UUID string `json:"uuid" gorm:"type:text;index"`
|
||||
Account Account `gorm:"foreignKey:UserID"`
|
||||
UUID string `json:"uuid" gorm:"type:text;index"`
|
||||
Email NullString `gorm:"index"`
|
||||
Password NullString `json:"-"`
|
||||
LastLoginAt *time.Time `json:"-"`
|
||||
MaxUSN int `json:"-" gorm:"default:0"`
|
||||
}
|
||||
|
||||
// Account is a model for an account
|
||||
type Account struct {
|
||||
Model
|
||||
UserID int `gorm:"index"`
|
||||
Email NullString
|
||||
Password NullString
|
||||
}
|
||||
|
||||
// Token is a model for a token
|
||||
type Token struct {
|
||||
Model
|
||||
|
|
|
|||
|
|
@ -67,8 +67,6 @@ type AuthParams struct {
|
|||
|
||||
// Auth is an authentication middleware
|
||||
func Auth(db *gorm.DB, next http.HandlerFunc, p *AuthParams) http.HandlerFunc {
|
||||
next = WithAccount(db, next)
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok, err := AuthWithSession(db, r)
|
||||
if !ok {
|
||||
|
|
@ -93,27 +91,6 @@ func Auth(db *gorm.DB, next http.HandlerFunc, p *AuthParams) http.HandlerFunc {
|
|||
ctx := context.WithUser(r.Context(), &user)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func WithAccount(db *gorm.DB, next http.HandlerFunc) http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user := context.User(r.Context())
|
||||
|
||||
var account database.Account
|
||||
err := db.Where("user_id = ?", user.ID).First(&account).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
DoError(w, "account not found", err, http.StatusForbidden)
|
||||
return
|
||||
} else if err != nil {
|
||||
DoError(w, "finding account", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithAccount(r.Context(), &account)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// TokenAuth is an authentication middleware with token
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ func TestGuestOnly(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("logged in", func(t *testing.T) {
|
||||
user := testutils.SetupUserData(db)
|
||||
user := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
req := testutils.MakeReq(server.URL, "GET", "/", "")
|
||||
res := testutils.HTTPAuthDo(t, db, req, user)
|
||||
|
||||
|
|
@ -67,8 +67,7 @@ func TestGuestOnly(t *testing.T) {
|
|||
func TestAuth(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
|
||||
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
|
||||
|
||||
session := database.Session{
|
||||
Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=",
|
||||
|
|
@ -175,7 +174,7 @@ func TestAuth(t *testing.T) {
|
|||
func TestTokenAuth(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
user := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
tok := database.Token{
|
||||
UserID: user.ID,
|
||||
Type: database.TokenTypeResetPassword,
|
||||
|
|
@ -241,9 +240,8 @@ func TestWithAccount(t *testing.T) {
|
|||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
t.Run("user with account", func(t *testing.T) {
|
||||
user := testutils.SetupUserData(db)
|
||||
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
|
||||
t.Run("authenticated user", func(t *testing.T) {
|
||||
user := testutils.SetupUserData(db, "alice@test.com", "pass1234")
|
||||
|
||||
server := httptest.NewServer(Auth(db, handler, nil))
|
||||
defer server.Close()
|
||||
|
|
@ -253,17 +251,4 @@ func TestWithAccount(t *testing.T) {
|
|||
|
||||
assert.Equal(t, res.StatusCode, http.StatusOK, "status code mismatch")
|
||||
})
|
||||
|
||||
t.Run("user without account", func(t *testing.T) {
|
||||
user := testutils.SetupUserData(db)
|
||||
// Note: not creating account for this user
|
||||
|
||||
server := httptest.NewServer(Auth(db, handler, nil))
|
||||
defer server.Close()
|
||||
|
||||
req := testutils.MakeReq(server.URL, "GET", "/", "")
|
||||
res := testutils.HTTPAuthDo(t, db, req, user)
|
||||
|
||||
assert.Equal(t, res.StatusCode, http.StatusForbidden, "status code mismatch")
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,8 +30,8 @@ import (
|
|||
func TestGetNote(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
anotherUser := testutils.SetupUserData(db)
|
||||
user := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
anotherUser := testutils.SetupUserData(db, "another@test.com", "password123")
|
||||
|
||||
b1 := database.Book{
|
||||
UUID: testutils.MustUUID(t),
|
||||
|
|
@ -98,7 +98,7 @@ func TestGetNote(t *testing.T) {
|
|||
func TestGetNote_nonexistent(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
user := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
|
||||
b1 := database.Book{
|
||||
UUID: testutils.MustUUID(t),
|
||||
|
|
|
|||
|
|
@ -29,8 +29,8 @@ import (
|
|||
func TestViewNote(t *testing.T) {
|
||||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
user := testutils.SetupUserData(db)
|
||||
anotherUser := testutils.SetupUserData(db)
|
||||
user := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
anotherUser := testutils.SetupUserData(db, "another@test.com", "password123")
|
||||
|
||||
b1 := database.Book{
|
||||
UUID: testutils.MustUUID(t),
|
||||
|
|
|
|||
|
|
@ -29,9 +29,9 @@ type Session struct {
|
|||
}
|
||||
|
||||
// New returns a new session for the given user
|
||||
func New(user database.User, account database.Account) Session {
|
||||
func New(user database.User) Session {
|
||||
return Session{
|
||||
UUID: user.UUID,
|
||||
Email: account.Email.String,
|
||||
Email: user.Email.String,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,33 +27,34 @@ import (
|
|||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
u1 := database.User{UUID: "0f5f0054-d23f-4be1-b5fb-57673109e9cb"}
|
||||
a1 := database.Account{Email: database.ToNullString("alice@example.com")}
|
||||
u1 := database.User{
|
||||
UUID: "0f5f0054-d23f-4be1-b5fb-57673109e9cb",
|
||||
Email: database.ToNullString("alice@example.com"),
|
||||
}
|
||||
|
||||
u2 := database.User{UUID: "718a1041-bbe6-496e-bbe4-ea7e572c295e"}
|
||||
a2 := database.Account{Email: database.ToNullString("bob@example.com")}
|
||||
u2 := database.User{
|
||||
UUID: "718a1041-bbe6-496e-bbe4-ea7e572c295e",
|
||||
Email: database.ToNullString("bob@example.com"),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
user database.User
|
||||
account database.Account
|
||||
user database.User
|
||||
}{
|
||||
{
|
||||
user: u1,
|
||||
account: a1,
|
||||
user: u1,
|
||||
},
|
||||
{
|
||||
user: u2,
|
||||
account: a2,
|
||||
user: u2,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("user %d", idx), func(t *testing.T) {
|
||||
// Execute
|
||||
got := New(tc.user, tc.account)
|
||||
got := New(tc.user)
|
||||
expected := Session{
|
||||
UUID: tc.user.UUID,
|
||||
Email: tc.account.Email.String,
|
||||
Email: tc.user.Email.String,
|
||||
}
|
||||
|
||||
assert.DeepEqual(t, got, expected, "result mismatch")
|
||||
|
|
|
|||
|
|
@ -82,9 +82,6 @@ func ClearData(db *gorm.DB) {
|
|||
if err := db.Where("1 = 1").Delete(&database.Session{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear sessions"))
|
||||
}
|
||||
if err := db.Where("1 = 1").Delete(&database.Account{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear accounts"))
|
||||
}
|
||||
if err := db.Where("1 = 1").Delete(&database.User{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear users"))
|
||||
}
|
||||
|
|
@ -99,15 +96,22 @@ func MustUUID(t *testing.T) string {
|
|||
return uuid
|
||||
}
|
||||
|
||||
// SetupUserData creates and returns a new user for testing purposes
|
||||
func SetupUserData(db *gorm.DB) database.User {
|
||||
// SetupUserData creates and returns a new user with email and password for testing purposes
|
||||
func SetupUserData(db *gorm.DB, email, password string) database.User {
|
||||
uuid, err := helpers.GenUUID()
|
||||
if err != nil {
|
||||
panic(errors.Wrap(err, "Failed to generate UUID"))
|
||||
}
|
||||
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
panic(errors.Wrap(err, "Failed to hash password"))
|
||||
}
|
||||
|
||||
user := database.User{
|
||||
UUID: uuid,
|
||||
UUID: uuid,
|
||||
Email: database.ToNullString(email),
|
||||
Password: database.ToNullString(string(hashedPassword)),
|
||||
}
|
||||
|
||||
if err := db.Save(&user).Error; err != nil {
|
||||
|
|
@ -117,28 +121,6 @@ func SetupUserData(db *gorm.DB) database.User {
|
|||
return user
|
||||
}
|
||||
|
||||
// SetupAccountData creates and returns a new account for the user
|
||||
func SetupAccountData(db *gorm.DB, user database.User, email, password string) database.Account {
|
||||
account := database.Account{
|
||||
UserID: user.ID,
|
||||
}
|
||||
if email != "" {
|
||||
account.Email = database.ToNullString(email)
|
||||
}
|
||||
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
panic(errors.Wrap(err, "Failed to hash password"))
|
||||
}
|
||||
account.Password = database.ToNullString(string(hashedPassword))
|
||||
|
||||
if err := db.Save(&account).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to prepare account"))
|
||||
}
|
||||
|
||||
return account
|
||||
}
|
||||
|
||||
// SetupSession creates and returns a new user session
|
||||
func SetupSession(db *gorm.DB, user database.User) database.Session {
|
||||
session := database.Session{
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ func TestCreate(t *testing.T) {
|
|||
db := testutils.InitMemoryDB(t)
|
||||
|
||||
// Set up
|
||||
u := testutils.SetupUserData(db)
|
||||
u := testutils.SetupUserData(db, "user@test.com", "password123")
|
||||
|
||||
// Execute
|
||||
tok, err := Create(db, u.ID, tc.kind)
|
||||
|
|
|
|||
|
|
@ -50,9 +50,8 @@ type Alert struct {
|
|||
type Data struct {
|
||||
Alert *Alert
|
||||
// CSRF template.HTML
|
||||
User *database.User
|
||||
Account *database.Account
|
||||
Yield map[string]interface{}
|
||||
User *database.User
|
||||
Yield map[string]interface{}
|
||||
}
|
||||
|
||||
func getErrMessage(err error) string {
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@
|
|||
</div>
|
||||
|
||||
<div class="email">
|
||||
{{.Account.Email.String}}
|
||||
{{.User.Email.String}}
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -108,14 +108,13 @@ func (v *View) Render(w http.ResponseWriter, r *http.Request, data *Data, status
|
|||
}
|
||||
|
||||
vd.User = context.User(r.Context())
|
||||
vd.Account = context.Account(r.Context())
|
||||
|
||||
// Put user data in Yield
|
||||
if vd.Yield == nil {
|
||||
vd.Yield = map[string]interface{}{}
|
||||
}
|
||||
if vd.Account != nil {
|
||||
vd.Yield["Email"] = vd.Account.Email.String
|
||||
if vd.User != nil {
|
||||
vd.Yield["Email"] = vd.User.Email.String
|
||||
}
|
||||
vd.Yield["CurrentPath"] = r.URL.Path
|
||||
vd.Yield["Standalone"] = buildinfo.Standalone
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue