diff --git a/pkg/e2e/sync_test.go b/pkg/e2e/sync_test.go index c2a1c4ee..73ea3b31 100644 --- a/pkg/e2e/sync_test.go +++ b/pkg/e2e/sync_test.go @@ -136,8 +136,7 @@ func TestMain(m *testing.M) { // helpers func setupUser(t *testing.T, db *cliDatabase.DB) database.User { - user := apitest.SetupUserData(serverDb) - apitest.SetupAccountData(serverDb, user, "alice@example.com", "pass1234") + user := apitest.SetupUserData(serverDb, "alice@example.com", "pass1234") return user } @@ -4255,8 +4254,7 @@ func TestSync_EmptyServer(t *testing.T) { // Step 1: Set up user on Server A and sync apiEndpointA := fmt.Sprintf("%s/api", serverA.URL) - userA := apitest.SetupUserData(serverDbA) - apitest.SetupAccountData(serverDbA, userA, "alice@example.com", "pass1234") + userA := apitest.SetupUserData(serverDbA, "alice@example.com", "pass1234") sessionA := apitest.SetupSession(serverDbA, userA) cliDatabase.MustExec(t, "inserting session_key", ctx.DB, "INSERT INTO system (key, value) VALUES (?, ?)", consts.SystemSessionKey, sessionA.Key) cliDatabase.MustExec(t, "inserting session_key_expiry", ctx.DB, "INSERT INTO system (key, value) VALUES (?, ?)", consts.SystemSessionKeyExpiry, sessionA.ExpiresAt.Unix()) @@ -4280,8 +4278,7 @@ func TestSync_EmptyServer(t *testing.T) { apiEndpointB := fmt.Sprintf("%s/api", serverB.URL) // Set up user on Server B - userB := apitest.SetupUserData(serverDbB) - apitest.SetupAccountData(serverDbB, userB, "alice@example.com", "pass1234") + userB := apitest.SetupUserData(serverDbB, "alice@example.com", "pass1234") sessionB := apitest.SetupSession(serverDbB, userB) cliDatabase.MustExec(t, "updating session_key for B", ctx.DB, "UPDATE system SET value = ? WHERE key = ?", sessionB.Key, consts.SystemSessionKey) cliDatabase.MustExec(t, "updating session_key_expiry for B", ctx.DB, "UPDATE system SET value = ? WHERE key = ?", sessionB.ExpiresAt.Unix(), consts.SystemSessionKeyExpiry) diff --git a/pkg/server/app/books_test.go b/pkg/server/app/books_test.go index 85df4770..66a27077 100644 --- a/pkg/server/app/books_test.go +++ b/pkg/server/app/books_test.go @@ -56,10 +56,10 @@ func TestCreateBook(t *testing.T) { func() { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) + user := testutils.SetupUserData(db, "user@test.com", "password123") testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx)) - anotherUser := testutils.SetupUserData(db) + anotherUser := testutils.SetupUserData(db, "another@test.com", "password123") testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx)) a := NewTest() @@ -122,10 +122,10 @@ func TestDeleteBook(t *testing.T) { func() { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) + user := testutils.SetupUserData(db, "user@test.com", "password123") testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx)) - anotherUser := testutils.SetupUserData(db) + anotherUser := testutils.SetupUserData(db, "another@test.com", "password123") testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx)) book := database.Book{UserID: user.ID, Label: "js", Deleted: false} @@ -201,10 +201,10 @@ func TestUpdateBook(t *testing.T) { func() { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) + user := testutils.SetupUserData(db, "user@test.com", "password123") testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx)) - anotherUser := testutils.SetupUserData(db) + anotherUser := testutils.SetupUserData(db, "another@test.com", "password123") testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx)) b := database.Book{UserID: user.ID, Deleted: false, Label: tc.expectedLabel} diff --git a/pkg/server/app/helpers_test.go b/pkg/server/app/helpers_test.go index 2c7a2828..ad309514 100644 --- a/pkg/server/app/helpers_test.go +++ b/pkg/server/app/helpers_test.go @@ -48,7 +48,7 @@ func TestIncremenetUserUSN(t *testing.T) { func() { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) + user := testutils.SetupUserData(db, "user@test.com", "password123") testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.maxUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx)) // execute diff --git a/pkg/server/app/notes_test.go b/pkg/server/app/notes_test.go index 42a0bd8a..38195079 100644 --- a/pkg/server/app/notes_test.go +++ b/pkg/server/app/notes_test.go @@ -77,11 +77,11 @@ func TestCreateNote(t *testing.T) { func() { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) + user := testutils.SetupUserData(db, "user@test.com", "password123") testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx)) fmt.Println(user) - anotherUser := testutils.SetupUserData(db) + anotherUser := testutils.SetupUserData(db, "another@test.com", "password123") testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx)) b1 := database.Book{UserID: user.ID, Label: "js", Deleted: false} @@ -130,7 +130,7 @@ func TestCreateNote(t *testing.T) { func TestCreateNote_EmptyBody(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) + user := testutils.SetupUserData(db, "user@test.com", "password123") b1 := database.Book{UserID: user.ID, Label: "testBook"} testutils.MustExec(t, db.Save(&b1), "preparing book") @@ -169,10 +169,10 @@ func TestUpdateNote(t *testing.T) { t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) + user := testutils.SetupUserData(db, "user@test.com", "password123") testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), "preparing user max_usn for test case") - anotherUser := testutils.SetupUserData(db) + anotherUser := testutils.SetupUserData(db, "another@test.com", "password123") testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), "preparing user max_usn for test case") b1 := database.Book{UserID: user.ID, Label: "js", Deleted: false} @@ -234,7 +234,7 @@ func TestUpdateNote(t *testing.T) { func TestUpdateNote_SameContent(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) + user := testutils.SetupUserData(db, "user@test.com", "password123") b1 := database.Book{UserID: user.ID, Label: "testBook"} testutils.MustExec(t, db.Save(&b1), "preparing book") @@ -291,10 +291,10 @@ func TestDeleteNote(t *testing.T) { func() { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) + user := testutils.SetupUserData(db, "user@test.com", "password123") testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx)) - anotherUser := testutils.SetupUserData(db) + anotherUser := testutils.SetupUserData(db, "another@test.com", "password123") testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx)) b1 := database.Book{UserID: user.ID, Label: "testBook"} @@ -351,7 +351,7 @@ func TestDeleteNote(t *testing.T) { func TestGetNotes_FTSSearch(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) + user := testutils.SetupUserData(db, "user@test.com", "password123") b1 := database.Book{UserID: user.ID, Label: "testBook"} testutils.MustExec(t, db.Save(&b1), "preparing book") @@ -415,7 +415,7 @@ func TestGetNotes_FTSSearch(t *testing.T) { func TestGetNotes_FTSSearch_Snippet(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) + user := testutils.SetupUserData(db, "user@test.com", "password123") b1 := database.Book{UserID: user.ID, Label: "testBook"} testutils.MustExec(t, db.Save(&b1), "preparing book") @@ -449,7 +449,7 @@ func TestGetNotes_FTSSearch_Snippet(t *testing.T) { func TestGetNotes_FTSSearch_ShortWord(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) + user := testutils.SetupUserData(db, "user@test.com", "password123") b1 := database.Book{UserID: user.ID, Label: "testBook"} testutils.MustExec(t, db.Save(&b1), "preparing book") @@ -481,7 +481,7 @@ func TestGetNotes_FTSSearch_ShortWord(t *testing.T) { func TestGetNotes_All(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) + user := testutils.SetupUserData(db, "user@test.com", "password123") b1 := database.Book{UserID: user.ID, Label: "testBook"} testutils.MustExec(t, db.Save(&b1), "preparing book") diff --git a/pkg/server/app/users.go b/pkg/server/app/users.go index b3993d55..6e9535e7 100644 --- a/pkg/server/app/users.go +++ b/pkg/server/app/users.go @@ -65,7 +65,7 @@ func (a *App) CreateUser(email, password string, passwordConfirmation string) (d tx := a.DB.Begin() var count int64 - if err := tx.Model(database.Account{}).Where("email = ?", email).Count(&count).Error; err != nil { + if err := tx.Model(&database.User{}).Where("email = ?", email).Count(&count).Error; err != nil { return database.User{}, pkgErrors.Wrap(err, "counting user") } if count > 0 { @@ -85,21 +85,14 @@ func (a *App) CreateUser(email, password string, passwordConfirmation string) (d } user := database.User{ - UUID: uuid, + UUID: uuid, + Email: database.ToNullString(email), + Password: database.ToNullString(string(hashedPassword)), } if err = tx.Save(&user).Error; err != nil { tx.Rollback() return database.User{}, pkgErrors.Wrap(err, "saving user") } - account := database.Account{ - Email: database.ToNullString(email), - Password: database.ToNullString(string(hashedPassword)), - UserID: user.ID, - } - if err = tx.Save(&account).Error; err != nil { - tx.Rollback() - return database.User{}, pkgErrors.Wrap(err, "saving account") - } if err := a.TouchLastLoginAt(user, tx); err != nil { tx.Rollback() @@ -111,42 +104,36 @@ func (a *App) CreateUser(email, password string, passwordConfirmation string) (d return user, nil } -// GetAccountByEmail finds an account by email -func (a *App) GetAccountByEmail(email string) (*database.Account, error) { - var account database.Account - err := a.DB.Where("email = ?", email).First(&account).Error +// GetUserByEmail finds a user by email +func (a *App) GetUserByEmail(email string) (*database.User, error) { + var user database.User + err := a.DB.Where("email = ?", email).First(&user).Error if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrNotFound } else if err != nil { return nil, err } - return &account, nil + return &user, nil } // Authenticate authenticates a user func (a *App) Authenticate(email, password string) (*database.User, error) { - account, err := a.GetAccountByEmail(email) + user, err := a.GetUserByEmail(email) if err != nil { return nil, err } - err = bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte(password)) + err = bcrypt.CompareHashAndPassword([]byte(user.Password.String), []byte(password)) if err != nil { return nil, ErrLoginInvalid } - var user database.User - err = a.DB.Where("id = ?", account.UserID).First(&user).Error - if err != nil { - return nil, pkgErrors.Wrap(err, "finding user") - } - - return &user, nil + return user, nil } -// UpdateAccountPassword updates an account's password with validation -func UpdateAccountPassword(db *gorm.DB, account *database.Account, newPassword string) error { +// UpdateUserPassword updates a user's password with validation +func UpdateUserPassword(db *gorm.DB, user *database.User, newPassword string) error { // Validate password if err := validatePassword(newPassword); err != nil { return err @@ -159,25 +146,25 @@ func UpdateAccountPassword(db *gorm.DB, account *database.Account, newPassword s } // Update the password - if err := db.Model(&account).Update("password", string(hashedPassword)).Error; err != nil { + if err := db.Model(&user).Update("password", string(hashedPassword)).Error; err != nil { return pkgErrors.Wrap(err, "updating password") } return nil } -// RemoveUser removes a user and their account from the system +// RemoveUser removes a user from the system // Returns an error if the user has any notes or books func (a *App) RemoveUser(email string) error { - // Find the account and user - account, err := a.GetAccountByEmail(email) + // Find the user + user, err := a.GetUserByEmail(email) if err != nil { return err } // Check if user has any notes var noteCount int64 - if err := a.DB.Model(&database.Note{}).Where("user_id = ? AND deleted = ?", account.UserID, false).Count(¬eCount).Error; err != nil { + if err := a.DB.Model(&database.Note{}).Where("user_id = ? AND deleted = ?", user.ID, false).Count(¬eCount).Error; err != nil { return pkgErrors.Wrap(err, "counting notes") } if noteCount > 0 { @@ -186,34 +173,18 @@ func (a *App) RemoveUser(email string) error { // Check if user has any books var bookCount int64 - if err := a.DB.Model(&database.Book{}).Where("user_id = ? AND deleted = ?", account.UserID, false).Count(&bookCount).Error; err != nil { + if err := a.DB.Model(&database.Book{}).Where("user_id = ? AND deleted = ?", user.ID, false).Count(&bookCount).Error; err != nil { return pkgErrors.Wrap(err, "counting books") } if bookCount > 0 { return ErrUserHasExistingResources } - // Delete account and user in a transaction - tx := a.DB.Begin() - - if err := tx.Delete(&account).Error; err != nil { - tx.Rollback() - return pkgErrors.Wrap(err, "deleting account") - } - - var user database.User - if err := tx.Where("id = ?", account.UserID).First(&user).Error; err != nil { - tx.Rollback() - return pkgErrors.Wrap(err, "finding user") - } - - if err := tx.Delete(&user).Error; err != nil { - tx.Rollback() + // Delete user + if err := a.DB.Delete(&user).Error; err != nil { return pkgErrors.Wrap(err, "deleting user") } - tx.Commit() - return nil } diff --git a/pkg/server/app/users_test.go b/pkg/server/app/users_test.go index a4c3a60d..90184fec 100644 --- a/pkg/server/app/users_test.go +++ b/pkg/server/app/users_test.go @@ -82,21 +82,20 @@ func TestCreateUser_ProValue(t *testing.T) { } -func TestGetAccountByEmail(t *testing.T) { +func TestGetUserByEmail(t *testing.T) { t.Run("success", func(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@example.com", "password123") + user := testutils.SetupUserData(db, "alice@example.com", "password123") a := NewTest() a.DB = db - account, err := a.GetAccountByEmail("alice@example.com") + foundUser, err := a.GetUserByEmail("alice@example.com") assert.Equal(t, err, nil, "should not error") - assert.Equal(t, account.Email.String, "alice@example.com", "email mismatch") - assert.Equal(t, account.UserID, user.ID, "user ID mismatch") + assert.Equal(t, foundUser.Email.String, "alice@example.com", "email mismatch") + assert.Equal(t, foundUser.ID, user.ID, "user ID mismatch") }) t.Run("not found", func(t *testing.T) { @@ -105,10 +104,10 @@ func TestGetAccountByEmail(t *testing.T) { a := NewTest() a.DB = db - account, err := a.GetAccountByEmail("nonexistent@example.com") + user, err := a.GetUserByEmail("nonexistent@example.com") assert.Equal(t, err, ErrNotFound, "should return ErrNotFound") - assert.Equal(t, account, (*database.Account)(nil), "account should be nil") + assert.Equal(t, user, (*database.User)(nil), "user should be nil") }) } @@ -124,25 +123,21 @@ func TestCreateUser(t *testing.T) { var userCount int64 testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user") - assert.Equal(t, userCount, int64(1), "book count mismatch") + assert.Equal(t, userCount, int64(1), "user count mismatch") - var accountCount int64 - var accountRecord database.Account - testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account") - testutils.MustExec(t, db.First(&accountRecord), "finding account") + var userRecord database.User + testutils.MustExec(t, db.First(&userRecord), "finding user") - assert.Equal(t, accountCount, int64(1), "account count mismatch") - assert.Equal(t, accountRecord.Email.String, "alice@example.com", "account email mismatch") + assert.Equal(t, userRecord.Email.String, "alice@example.com", "user email mismatch") - passwordErr := bcrypt.CompareHashAndPassword([]byte(accountRecord.Password.String), []byte("pass1234")) + passwordErr := bcrypt.CompareHashAndPassword([]byte(userRecord.Password.String), []byte("pass1234")) assert.Equal(t, passwordErr, nil, "Password mismatch") }) t.Run("duplicate email", func(t *testing.T) { db := testutils.InitMemoryDB(t) - aliceUser := testutils.SetupUserData(db) - testutils.SetupAccountData(db, aliceUser, "alice@example.com", "somepassword") + testutils.SetupUserData(db, "alice@example.com", "somepassword") a := NewTest() a.DB = db @@ -150,116 +145,109 @@ func TestCreateUser(t *testing.T) { assert.Equal(t, err, ErrDuplicateEmail, "error mismatch") - var userCount, accountCount int64 + var userCount int64 testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user") - testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account") assert.Equal(t, userCount, int64(1), "user count mismatch") - assert.Equal(t, accountCount, int64(1), "account count mismatch") }) } -func TestUpdateAccountPassword(t *testing.T) { +func TestUpdateUserPassword(t *testing.T) { t.Run("success", func(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) - account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123") + user := testutils.SetupUserData(db, "alice@example.com", "oldpassword123") - err := UpdateAccountPassword(db, &account, "newpassword123") + err := UpdateUserPassword(db, &user, "newpassword123") assert.Equal(t, err, nil, "should not error") // Verify password was updated in database - var updatedAccount database.Account - testutils.MustExec(t, db.Where("id = ?", account.ID).First(&updatedAccount), "finding updated account") + var updatedUser database.User + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&updatedUser), "finding updated user") // Verify new password works - passwordErr := bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("newpassword123")) + passwordErr := bcrypt.CompareHashAndPassword([]byte(updatedUser.Password.String), []byte("newpassword123")) assert.Equal(t, passwordErr, nil, "New password should match") // Verify old password no longer works - oldPasswordErr := bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("oldpassword123")) + oldPasswordErr := bcrypt.CompareHashAndPassword([]byte(updatedUser.Password.String), []byte("oldpassword123")) assert.NotEqual(t, oldPasswordErr, nil, "Old password should not match") }) t.Run("password too short", func(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) - account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123") + user := testutils.SetupUserData(db, "alice@example.com", "oldpassword123") - err := UpdateAccountPassword(db, &account, "short") + err := UpdateUserPassword(db, &user, "short") assert.Equal(t, err, ErrPasswordTooShort, "should return ErrPasswordTooShort") // Verify password was NOT updated in database - var unchangedAccount database.Account - testutils.MustExec(t, db.Where("id = ?", account.ID).First(&unchangedAccount), "finding unchanged account") + var unchangedUser database.User + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&unchangedUser), "finding unchanged user") // Verify old password still works - passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedAccount.Password.String), []byte("oldpassword123")) + passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedUser.Password.String), []byte("oldpassword123")) assert.Equal(t, passwordErr, nil, "Old password should still match") }) t.Run("empty password", func(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) - account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123") + user := testutils.SetupUserData(db, "alice@example.com", "oldpassword123") - err := UpdateAccountPassword(db, &account, "") + err := UpdateUserPassword(db, &user, "") assert.Equal(t, err, ErrPasswordTooShort, "should return ErrPasswordTooShort") // Verify password was NOT updated in database - var unchangedAccount database.Account - testutils.MustExec(t, db.Where("id = ?", account.ID).First(&unchangedAccount), "finding unchanged account") + var unchangedUser database.User + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&unchangedUser), "finding unchanged user") // Verify old password still works - passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedAccount.Password.String), []byte("oldpassword123")) + passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedUser.Password.String), []byte("oldpassword123")) assert.Equal(t, passwordErr, nil, "Old password should still match") }) t.Run("transaction rollback", func(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) - account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123") + user := testutils.SetupUserData(db, "alice@example.com", "oldpassword123") - // Start a transaction and rollback to verify UpdateAccountPassword respects transactions + // Start a transaction and rollback to verify UpdateUserPassword respects transactions tx := db.Begin() - err := UpdateAccountPassword(tx, &account, "newpassword123") + err := UpdateUserPassword(tx, &user, "newpassword123") assert.Equal(t, err, nil, "should not error") tx.Rollback() // Verify password was NOT updated after rollback - var unchangedAccount database.Account - testutils.MustExec(t, db.Where("id = ?", account.ID).First(&unchangedAccount), "finding unchanged account") + var unchangedUser database.User + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&unchangedUser), "finding unchanged user") // Verify old password still works - passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedAccount.Password.String), []byte("oldpassword123")) + passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedUser.Password.String), []byte("oldpassword123")) assert.Equal(t, passwordErr, nil, "Old password should still match after rollback") }) t.Run("transaction commit", func(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) - account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123") + user := testutils.SetupUserData(db, "alice@example.com", "oldpassword123") - // Start a transaction and commit to verify UpdateAccountPassword respects transactions + // Start a transaction and commit to verify UpdateUserPassword respects transactions tx := db.Begin() - err := UpdateAccountPassword(tx, &account, "newpassword123") + err := UpdateUserPassword(tx, &user, "newpassword123") assert.Equal(t, err, nil, "should not error") tx.Commit() // Verify password was updated after commit - var updatedAccount database.Account - testutils.MustExec(t, db.Where("id = ?", account.ID).First(&updatedAccount), "finding updated account") + var updatedUser database.User + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&updatedUser), "finding updated user") // Verify new password works - passwordErr := bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("newpassword123")) + passwordErr := bcrypt.CompareHashAndPassword([]byte(updatedUser.Password.String), []byte("newpassword123")) assert.Equal(t, passwordErr, nil, "New password should match after commit") }) } @@ -268,8 +256,7 @@ func TestRemoveUser(t *testing.T) { t.Run("success", func(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@example.com", "password123") + testutils.SetupUserData(db, "alice@example.com", "password123") a := NewTest() a.DB = db @@ -282,11 +269,6 @@ func TestRemoveUser(t *testing.T) { var userCount int64 testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting users") assert.Equal(t, userCount, int64(0), "user should be deleted") - - // Verify account was deleted - var accountCount int64 - testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting accounts") - assert.Equal(t, accountCount, int64(0), "account should be deleted") }) t.Run("user not found", func(t *testing.T) { @@ -303,8 +285,7 @@ func TestRemoveUser(t *testing.T) { t.Run("user has notes", func(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@example.com", "password123") + user := testutils.SetupUserData(db, "alice@example.com", "password123") book := database.Book{UserID: user.ID, Label: "testbook", Deleted: false} testutils.MustExec(t, db.Save(&book), "creating book") @@ -324,17 +305,12 @@ func TestRemoveUser(t *testing.T) { testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting users") assert.Equal(t, userCount, int64(1), "user should not be deleted") - // Verify account was NOT deleted - var accountCount int64 - testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting accounts") - assert.Equal(t, accountCount, int64(1), "account should not be deleted") }) t.Run("user has books", func(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@example.com", "password123") + user := testutils.SetupUserData(db, "alice@example.com", "password123") book := database.Book{UserID: user.ID, Label: "testbook", Deleted: false} testutils.MustExec(t, db.Save(&book), "creating book") @@ -351,17 +327,12 @@ func TestRemoveUser(t *testing.T) { testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting users") assert.Equal(t, userCount, int64(1), "user should not be deleted") - // Verify account was NOT deleted - var accountCount int64 - testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting accounts") - assert.Equal(t, accountCount, int64(1), "account should not be deleted") }) t.Run("user has deleted notes and books", func(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@example.com", "password123") + user := testutils.SetupUserData(db, "alice@example.com", "password123") book := database.Book{UserID: user.ID, Label: "testbook", Deleted: false} testutils.MustExec(t, db.Save(&book), "creating book") @@ -385,9 +356,5 @@ func TestRemoveUser(t *testing.T) { testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting users") assert.Equal(t, userCount, int64(0), "user should be deleted") - // Verify account was deleted - var accountCount int64 - testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting accounts") - assert.Equal(t, accountCount, int64(0), "account should be deleted") }) } diff --git a/pkg/server/cmd/user.go b/pkg/server/cmd/user.go index 6123cdae..ec5b4ea2 100644 --- a/pkg/server/cmd/user.go +++ b/pkg/server/cmd/user.go @@ -81,12 +81,12 @@ func userRemoveCmd(args []string, stdin io.Reader) { defer cleanup() // Check if user exists first - _, err := a.GetAccountByEmail(*email) + _, err := a.GetUserByEmail(*email) if err != nil { if errors.Is(err, app.ErrNotFound) { fmt.Printf("Error: user with email %s not found\n", *email) } else { - log.ErrorWrap(err, "finding account") + log.ErrorWrap(err, "finding user") } os.Exit(1) } @@ -133,19 +133,19 @@ func userResetPasswordCmd(args []string) { a, cleanup := setupAppWithDB(fs, *dbPath) defer cleanup() - // Find the account - account, err := a.GetAccountByEmail(*email) + // Find the user + user, err := a.GetUserByEmail(*email) if err != nil { if errors.Is(err, app.ErrNotFound) { fmt.Printf("Error: user with email %s not found\n", *email) } else { - log.ErrorWrap(err, "finding account") + log.ErrorWrap(err, "finding user") } os.Exit(1) } // Update the password - if err := app.UpdateAccountPassword(a.DB, account, *password); err != nil { + if err := app.UpdateUserPassword(a.DB, user, *password); err != nil { log.ErrorWrap(err, "updating password") os.Exit(1) } diff --git a/pkg/server/cmd/user_test.go b/pkg/server/cmd/user_test.go index d3536d52..ea81832a 100644 --- a/pkg/server/cmd/user_test.go +++ b/pkg/server/cmd/user_test.go @@ -45,9 +45,9 @@ func TestUserCreateCmd(t *testing.T) { testutils.MustExec(t, db.Model(&database.User{}).Count(&count), "counting users") assert.Equal(t, count, int64(1), "should have 1 user") - var account database.Account - testutils.MustExec(t, db.Where("email = ?", "test@example.com").First(&account), "finding account") - assert.Equal(t, account.Email.String, "test@example.com", "email mismatch") + var user database.User + testutils.MustExec(t, db.Where("email = ?", "test@example.com").First(&user), "finding user") + assert.Equal(t, user.Email.String, "test@example.com", "email mismatch") } func TestUserRemoveCmd(t *testing.T) { @@ -55,8 +55,7 @@ func TestUserRemoveCmd(t *testing.T) { // Create a user first db := testutils.InitDB(tmpDB) - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "test@example.com", "password123") + testutils.SetupUserData(db, "test@example.com", "password123") sqlDB, _ := db.DB() sqlDB.Close() @@ -81,9 +80,8 @@ func TestUserResetPasswordCmd(t *testing.T) { // Create a user first db := testutils.InitDB(tmpDB) - user := testutils.SetupUserData(db) - account := testutils.SetupAccountData(db, user, "test@example.com", "oldpassword123") - oldPasswordHash := account.Password.String + user := testutils.SetupUserData(db, "test@example.com", "oldpassword123") + oldPasswordHash := user.Password.String sqlDB, _ := db.DB() sqlDB.Close() @@ -97,18 +95,18 @@ func TestUserResetPasswordCmd(t *testing.T) { sqlDB2.Close() }() - var updatedAccount database.Account - testutils.MustExec(t, db2.Where("email = ?", "test@example.com").First(&updatedAccount), "finding account") + var updatedUser database.User + testutils.MustExec(t, db2.Where("email = ?", "test@example.com").First(&updatedUser), "finding user") // Verify password hash changed - assert.Equal(t, updatedAccount.Password.String != oldPasswordHash, true, "password hash should be different") - assert.Equal(t, len(updatedAccount.Password.String) > 0, true, "password should be set") + assert.Equal(t, updatedUser.Password.String != oldPasswordHash, true, "password hash should be different") + assert.Equal(t, len(updatedUser.Password.String) > 0, true, "password should be set") // Verify new password works - err := bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("newpassword123")) + err := bcrypt.CompareHashAndPassword([]byte(updatedUser.Password.String), []byte("newpassword123")) assert.Equal(t, err, nil, "new password should match") // Verify old password doesn't work - err = bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("oldpassword123")) + err = bcrypt.CompareHashAndPassword([]byte(updatedUser.Password.String), []byte("oldpassword123")) assert.Equal(t, err != nil, true, "old password should not match") } diff --git a/pkg/server/context/user.go b/pkg/server/context/user.go index 77d66916..64171df7 100644 --- a/pkg/server/context/user.go +++ b/pkg/server/context/user.go @@ -25,9 +25,8 @@ import ( ) const ( - userKey privateKey = "user" - accountKey privateKey = "account" - tokenKey privateKey = "token" + userKey privateKey = "user" + tokenKey privateKey = "token" ) type privateKey string @@ -37,11 +36,6 @@ func WithUser(ctx context.Context, user *database.User) context.Context { return context.WithValue(ctx, userKey, user) } -// WithAccount creates a new context with the given account -func WithAccount(ctx context.Context, account *database.Account) context.Context { - return context.WithValue(ctx, accountKey, account) -} - // WithToken creates a new context with the given user func WithToken(ctx context.Context, tok *database.Token) context.Context { return context.WithValue(ctx, tokenKey, tok) @@ -59,17 +53,6 @@ func User(ctx context.Context) *database.User { return nil } -// Account retrieves an account from the given context. -func Account(ctx context.Context) *database.Account { - if temp := ctx.Value(accountKey); temp != nil { - if account, ok := temp.(*database.Account); ok { - return account - } - } - - return nil -} - // Token retrieves a token from the given context. func Token(ctx context.Context) *database.Token { if temp := ctx.Value(tokenKey); temp != nil { diff --git a/pkg/server/controllers/books_test.go b/pkg/server/controllers/books_test.go index d59dcd17..e2302f22 100644 --- a/pkg/server/controllers/books_test.go +++ b/pkg/server/controllers/books_test.go @@ -50,10 +50,8 @@ func TestGetBooks(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") - anotherUser := testutils.SetupUserData(db) - testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234") + user := testutils.SetupUserData(db, "alice@test.com", "pass1234") + anotherUser := testutils.SetupUserData(db, "bob@test.com", "pass1234") b1 := database.Book{ UUID: testutils.MustUUID(t), @@ -143,10 +141,8 @@ func TestGetBooksByName(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") - anotherUser := testutils.SetupUserData(db) - testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234") + user := testutils.SetupUserData(db, "alice@test.com", "pass1234") + anotherUser := testutils.SetupUserData(db, "bob@test.com", "pass1234") b1 := database.Book{ UUID: testutils.MustUUID(t), @@ -212,10 +208,8 @@ func TestGetBook(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") - anotherUser := testutils.SetupUserData(db) - testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234") + user := testutils.SetupUserData(db, "alice@test.com", "pass1234") + anotherUser := testutils.SetupUserData(db, "bob@test.com", "pass1234") b1 := database.Book{ UUID: testutils.MustUUID(t), @@ -276,10 +270,8 @@ func TestGetBookNonOwner(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") - nonOwner := testutils.SetupUserData(db) - testutils.SetupAccountData(db, nonOwner, "bob@test.com", "pass1234") + user := testutils.SetupUserData(db, "alice@test.com", "pass1234") + nonOwner := testutils.SetupUserData(db, "bob@test.com", "pass1234") b1 := database.Book{ UUID: testutils.MustUUID(t), @@ -314,8 +306,7 @@ func TestCreateBook(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + user := testutils.SetupUserData(db, "alice@test.com", "pass1234") testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn") req := testutils.MakeReq(server.URL, "POST", "/api/v3/books", `{"name": "js"}`) @@ -375,8 +366,7 @@ func TestCreateBook(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + user := testutils.SetupUserData(db, "alice@test.com", "pass1234") testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn") b1 := database.Book{ @@ -465,8 +455,7 @@ func TestUpdateBook(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + user := testutils.SetupUserData(db, "alice@test.com", "pass1234") testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn") b1 := database.Book{ @@ -550,11 +539,9 @@ func TestDeleteBook(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + user := testutils.SetupUserData(db, "alice@test.com", "pass1234") testutils.MustExec(t, db.Model(&user).Update("max_usn", 58), "preparing user max_usn") - anotherUser := testutils.SetupUserData(db) - testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234") + anotherUser := testutils.SetupUserData(db, "bob@test.com", "pass1234") testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 109), "preparing another user max_usn") b1 := database.Book{ diff --git a/pkg/server/controllers/notes_test.go b/pkg/server/controllers/notes_test.go index ce3f20c7..b69f7c1e 100644 --- a/pkg/server/controllers/notes_test.go +++ b/pkg/server/controllers/notes_test.go @@ -63,10 +63,8 @@ func TestGetNotes(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") - anotherUser := testutils.SetupUserData(db) - testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234") + user := testutils.SetupUserData(db, "alice@test.com", "pass1234") + anotherUser := testutils.SetupUserData(db, "bob@test.com", "pass1234") b1 := database.Book{ UUID: testutils.MustUUID(t), @@ -187,10 +185,8 @@ func TestGetNote(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "user@test.com", "pass1234") - anotherUser := testutils.SetupUserData(db) - testutils.SetupAccountData(db, anotherUser, "another@test.com", "pass1234") + user := testutils.SetupUserData(db, "user@test.com", "pass1234") + anotherUser := testutils.SetupUserData(db, "another@test.com", "pass1234") b1 := database.Book{ UUID: testutils.MustUUID(t), @@ -318,8 +314,7 @@ func TestCreateNote(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + user := testutils.SetupUserData(db, "alice@test.com", "pass1234") testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn") b1 := database.Book{ @@ -400,8 +395,7 @@ func TestDeleteNote(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + user := testutils.SetupUserData(db, "alice@test.com", "pass1234") testutils.MustExec(t, db.Model(&user).Update("max_usn", 981), "preparing user max_usn") b1 := database.Book{ @@ -559,8 +553,7 @@ func TestUpdateNote(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + user := testutils.SetupUserData(db, "alice@test.com", "pass1234") testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn") diff --git a/pkg/server/controllers/users.go b/pkg/server/controllers/users.go index b1945d8f..bf1e6384 100644 --- a/pkg/server/controllers/users.go +++ b/pkg/server/controllers/users.go @@ -307,23 +307,23 @@ func (u *Users) CreateResetToken(w http.ResponseWriter, r *http.Request) { return } - var account database.Account - err := u.app.DB.Where("email = ?", form.Email).First(&account).Error + var user database.User + err := u.app.DB.Where("email = ?", form.Email).First(&user).Error if errors.Is(err, gorm.ErrRecordNotFound) { return } if err != nil { - handleHTMLError(w, r, err, "finding account", u.PasswordResetView, vd) + handleHTMLError(w, r, err, "finding user", u.PasswordResetView, vd) return } - resetToken, err := token.Create(u.app.DB, account.UserID, database.TokenTypeResetPassword) + resetToken, err := token.Create(u.app.DB, user.ID, database.TokenTypeResetPassword) if err != nil { handleHTMLError(w, r, err, "generating token", u.PasswordResetView, vd) return } - if err := u.app.SendPasswordResetEmail(account.Email.String, resetToken.Value); err != nil { + if err := u.app.SendPasswordResetEmail(user.Email.String, resetToken.Value); err != nil { handleHTMLError(w, r, err, "sending password reset email", u.PasswordResetView, vd) return } @@ -396,8 +396,8 @@ func (u *Users) PasswordReset(w http.ResponseWriter, r *http.Request) { return } - var account database.Account - if err := u.app.DB.Where("user_id = ?", token.UserID).First(&account).Error; err != nil { + var user database.User + if err := u.app.DB.Where("id = ?", token.UserID).First(&user).Error; err != nil { handleHTMLError(w, r, err, "finding user", u.PasswordResetConfirmView, vd) return } @@ -405,7 +405,7 @@ func (u *Users) PasswordReset(w http.ResponseWriter, r *http.Request) { tx := u.app.DB.Begin() // Update the password - if err := app.UpdateAccountPassword(tx, &account, params.Password); err != nil { + if err := app.UpdateUserPassword(tx, &user, params.Password); err != nil { tx.Rollback() handleHTMLError(w, r, err, "updating password", u.PasswordResetConfirmView, vd) return @@ -417,7 +417,7 @@ func (u *Users) PasswordReset(w http.ResponseWriter, r *http.Request) { return } - if err := u.app.DeleteUserSessions(tx, account.UserID); err != nil { + if err := u.app.DeleteUserSessions(tx, user.ID); err != nil { tx.Rollback() handleHTMLError(w, r, err, "deleting user sessions", u.PasswordResetConfirmView, vd) return @@ -425,19 +425,13 @@ func (u *Users) PasswordReset(w http.ResponseWriter, r *http.Request) { tx.Commit() - var user database.User - if err := u.app.DB.Where("id = ?", account.UserID).First(&user).Error; err != nil { - handleHTMLError(w, r, err, "finding user", u.PasswordResetConfirmView, vd) - return - } - alert := views.Alert{ Level: views.AlertLvlSuccess, Message: "Password reset successful", } views.RedirectAlert(w, r, "/login", http.StatusFound, alert) - if err := u.app.SendPasswordResetAlertEmail(account.Email.String); err != nil { + if err := u.app.SendPasswordResetAlertEmail(user.Email.String); err != nil { log.ErrorWrap(err, "sending password reset email") } } @@ -493,14 +487,8 @@ func (u *Users) PasswordUpdate(w http.ResponseWriter, r *http.Request) { return } - var account database.Account - if err := u.app.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil { - handleHTMLError(w, r, err, "getting account", u.SettingView, vd) - return - } - password := []byte(form.OldPassword) - if err := bcrypt.CompareHashAndPassword([]byte(account.Password.String), password); err != nil { + if err := bcrypt.CompareHashAndPassword([]byte(user.Password.String), password); err != nil { log.WithFields(log.Fields{ "user_id": user.ID, }).Warn("invalid password update attempt") @@ -508,7 +496,7 @@ func (u *Users) PasswordUpdate(w http.ResponseWriter, r *http.Request) { return } - if err := app.UpdateAccountPassword(u.app.DB, &account, form.NewPassword); err != nil { + if err := app.UpdateUserPassword(u.app.DB, user, form.NewPassword); err != nil { handleHTMLError(w, r, err, "updating password", u.SettingView, vd) return } @@ -534,12 +522,6 @@ func (u *Users) ProfileUpdate(w http.ResponseWriter, r *http.Request) { return } - var account database.Account - if err := u.app.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil { - handleHTMLError(w, r, err, "getting account", u.SettingView, vd) - return - } - var form updateProfileForm if err := parseRequestData(r, &form); err != nil { handleHTMLError(w, r, err, "parsing payload", u.SettingView, vd) @@ -547,7 +529,7 @@ func (u *Users) ProfileUpdate(w http.ResponseWriter, r *http.Request) { } password := []byte(form.Password) - if err := bcrypt.CompareHashAndPassword([]byte(account.Password.String), password); err != nil { + if err := bcrypt.CompareHashAndPassword([]byte(user.Password.String), password); err != nil { log.WithFields(log.Fields{ "user_id": user.ID, }).Warn("invalid email update attempt") @@ -561,23 +543,13 @@ func (u *Users) ProfileUpdate(w http.ResponseWriter, r *http.Request) { return } - tx := u.app.DB.Begin() - if err := tx.Save(&user).Error; err != nil { - tx.Rollback() + user.Email.String = form.Email + + if err := u.app.DB.Save(&user).Error; err != nil { handleHTMLError(w, r, err, "saving user", u.SettingView, vd) return } - account.Email.String = form.Email - - if err := tx.Save(&account).Error; err != nil { - tx.Rollback() - handleHTMLError(w, r, err, "saving account", u.SettingView, vd) - return - } - - tx.Commit() - alert := views.Alert{ Level: views.AlertLvlSuccess, Message: "Email change successful", diff --git a/pkg/server/controllers/users_test.go b/pkg/server/controllers/users_test.go index 643cb016..03f6c53e 100644 --- a/pkg/server/controllers/users_test.go +++ b/pkg/server/controllers/users_test.go @@ -98,15 +98,14 @@ func TestJoin(t *testing.T) { // Test assert.StatusCodeEquals(t, res, http.StatusFound, "") - var account database.Account - testutils.MustExec(t, db.Where("email = ?", tc.email).First(&account), "finding account") - assert.Equal(t, account.Email.String, tc.email, "Email mismatch") - assert.NotEqual(t, account.UserID, 0, "UserID mismatch") - passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte(tc.password)) + var user database.User + testutils.MustExec(t, db.Where("email = ?", tc.email).First(&user), "finding account") + assert.Equal(t, user.Email.String, tc.email, "Email mismatch") + assert.NotEqual(t, user.ID, 0, "UserID mismatch") + passwordErr := bcrypt.CompareHashAndPassword([]byte(user.Password.String), []byte(tc.password)) assert.Equal(t, passwordErr, nil, "Password mismatch") - var user database.User - testutils.MustExec(t, db.Where("id = ?", account.UserID).First(&user), "finding user") + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&user), "finding user") assert.Equal(t, user.MaxUSN, 0, "MaxUSN mismatch") // welcome email @@ -140,11 +139,9 @@ func TestJoinError(t *testing.T) { // Test assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch") - var accountCount, userCount int64 - testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account") + var userCount int64 testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user") - assert.Equal(t, accountCount, int64(0), "accountCount mismatch") assert.Equal(t, userCount, int64(0), "userCount mismatch") }) @@ -168,11 +165,9 @@ func TestJoinError(t *testing.T) { // Test assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch") - var accountCount, userCount int64 - testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account") + var userCount int64 testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user") - assert.Equal(t, accountCount, int64(0), "accountCount mismatch") assert.Equal(t, userCount, int64(0), "userCount mismatch") }) @@ -198,11 +193,9 @@ func TestJoinError(t *testing.T) { // Test assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch") - var accountCount, userCount int64 - testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account") + var userCount int64 testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user") - assert.Equal(t, accountCount, int64(0), "accountCount mismatch") assert.Equal(t, userCount, int64(0), "userCount mismatch") }) } @@ -217,8 +210,7 @@ func TestJoinDuplicateEmail(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData(db) - testutils.SetupAccountData(db, u, "alice@example.com", "somepassword") + u := testutils.SetupUserData(db, "alice@example.com", "somepassword") dat := url.Values{} dat.Set("email", "alice@example.com") @@ -232,15 +224,13 @@ func TestJoinDuplicateEmail(t *testing.T) { // Test assert.StatusCodeEquals(t, res, http.StatusBadRequest, "status code mismatch") - var accountCount, userCount, verificationTokenCount int64 - testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account") + var userCount, verificationTokenCount int64 testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user") testutils.MustExec(t, db.Model(&database.Token{}).Count(&verificationTokenCount), "counting verification token") var user database.User testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user") - assert.Equal(t, accountCount, int64(1), "account count mismatch") assert.Equal(t, userCount, int64(1), "user count mismatch") assert.Equal(t, verificationTokenCount, int64(0), "verification_token should not have been created") assert.Equal(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch") @@ -268,11 +258,9 @@ func TestJoinDisabled(t *testing.T) { // Test assert.StatusCodeEquals(t, res, http.StatusNotFound, "status code mismatch") - var accountCount, userCount int64 - testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account") + var userCount int64 testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user") - assert.Equal(t, accountCount, int64(0), "account count mismatch") assert.Equal(t, userCount, int64(0), "user count mismatch") } @@ -286,8 +274,7 @@ func TestLogin(t *testing.T) { a.DB = db server := MustNewServer(t, &a) - u := testutils.SetupUserData(db) - testutils.SetupAccountData(db, u, "alice@example.com", "pass1234") + _ = testutils.SetupUserData(db, "alice@example.com", "pass1234") defer server.Close() // Execute @@ -346,8 +333,7 @@ func TestLogin(t *testing.T) { a.DB = db server := MustNewServer(t, &a) - u := testutils.SetupUserData(db) - testutils.SetupAccountData(db, u, "alice@example.com", "pass1234") + _ = testutils.SetupUserData(db, "alice@example.com", "pass1234") defer server.Close() var req *http.Request @@ -386,8 +372,7 @@ func TestLogin(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData(db) - testutils.SetupAccountData(db, u, "alice@example.com", "pass1234") + _ = testutils.SetupUserData(db, "alice@example.com", "pass1234") var req *http.Request if target == testutils.EndpointWeb { @@ -456,9 +441,8 @@ func TestLogout(t *testing.T) { a.DB = db server := MustNewServer(t, &a) - aliceUser := testutils.SetupUserData(db) - testutils.SetupAccountData(db, aliceUser, "alice@example.com", "pass1234") - anotherUser := testutils.SetupUserData(db) + aliceUser := testutils.SetupUserData(db, "alice@example.com", "pass1234") + anotherUser := testutils.SetupUserData(db, "bob@example.com", "password123") session1ExpiresAt := time.Now().Add(time.Hour * 24) session1 := database.Session{ @@ -570,8 +554,7 @@ func TestResetPassword(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData(db) - acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword") + u := testutils.SetupUserData(db, "alice@example.com", "oldpassword") tok := database.Token{ UserID: u.ID, Value: "MivFxYiSMMA4An9dP24DNQ==", @@ -593,7 +576,7 @@ func TestResetPassword(t *testing.T) { } testutils.MustExec(t, db.Save(&s2), "preparing user session 2") - anotherUser := testutils.SetupUserData(db) + anotherUser := testutils.SetupUserData(db, "bob@example.com", "password123") testutils.MustExec(t, db.Save(&database.Session{ Key: "some-session-key-3", UserID: anotherUser.ID, @@ -613,12 +596,12 @@ func TestResetPassword(t *testing.T) { assert.StatusCodeEquals(t, res, http.StatusFound, "Status code mismatch") var resetToken database.Token - var account database.Account + var user database.User testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token") - testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "finding account") + testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding account") assert.NotEqual(t, resetToken.UsedAt, nil, "reset_token UsedAt mismatch") - passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte("newpassword")) + passwordErr := bcrypt.CompareHashAndPassword([]byte(user.Password.String), []byte("newpassword")) assert.Equal(t, passwordErr, nil, "Password mismatch") var s1Count, s2Count int64 @@ -646,8 +629,7 @@ func TestResetPassword(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData(db) - acc := testutils.SetupAccountData(db, u, "alice@example.com", "somepassword") + u := testutils.SetupUserData(db, "alice@example.com", "somepassword") tok := database.Token{ UserID: u.ID, Value: "MivFxYiSMMA4An9dP24DNQ==", @@ -668,12 +650,12 @@ func TestResetPassword(t *testing.T) { assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismatch") var resetToken database.Token - var account database.Account + var user database.User testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token") - testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "finding account") + testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding account") - assert.Equal(t, acc.Password, account.Password, "password should not have been updated") - assert.Equal(t, acc.Password, account.Password, "password should not have been updated") + assert.Equal(t, u.Password, user.Password, "password should not have been updated") + assert.Equal(t, u.Password, user.Password, "password should not have been updated") assert.Equal(t, resetToken.UsedAt, (*time.Time)(nil), "used_at should be nil") }) @@ -687,8 +669,7 @@ func TestResetPassword(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData(db) - acc := testutils.SetupAccountData(db, u, "alice@example.com", "somepassword") + u := testutils.SetupUserData(db, "alice@example.com", "somepassword") tok := database.Token{ UserID: u.ID, Value: "MivFxYiSMMA4An9dP24DNQ==", @@ -710,10 +691,10 @@ func TestResetPassword(t *testing.T) { assert.StatusCodeEquals(t, res, http.StatusGone, "Status code mismatch") var resetToken database.Token - var account database.Account + var user database.User testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token") - testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "failed to find account") - assert.Equal(t, acc.Password, account.Password, "password should not have been updated") + testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "failed to find account") + assert.Equal(t, u.Password, user.Password, "password should not have been updated") assert.Equal(t, resetToken.UsedAt, (*time.Time)(nil), "used_at should be nil") }) @@ -727,8 +708,7 @@ func TestResetPassword(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData(db) - acc := testutils.SetupAccountData(db, u, "alice@example.com", "somepassword") + u := testutils.SetupUserData(db, "alice@example.com", "somepassword") usedAt := time.Now().Add(time.Hour * -11).UTC() tok := database.Token{ @@ -753,10 +733,10 @@ func TestResetPassword(t *testing.T) { assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismatch") var resetToken database.Token - var account database.Account + var user database.User testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token") - testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "failed to find account") - assert.Equal(t, acc.Password, account.Password, "password should not have been updated") + testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "failed to find account") + assert.Equal(t, u.Password, user.Password, "password should not have been updated") resetTokenUsedAtUTC := resetToken.UsedAt.UTC() if resetTokenUsedAtUTC.Year() != usedAt.Year() || @@ -782,8 +762,7 @@ func TestCreateResetToken(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData(db) - testutils.SetupAccountData(db, u, "alice@example.com", "somepassword") + u := testutils.SetupUserData(db, "alice@example.com", "somepassword") // Execute dat := url.Values{} @@ -816,8 +795,7 @@ func TestCreateResetToken(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData(db) - testutils.SetupAccountData(db, u, "alice@example.com", "somepassword") + _ = testutils.SetupUserData(db, "alice@example.com", "somepassword") // Execute dat := url.Values{} @@ -846,8 +824,7 @@ func TestUpdatePassword(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword") + user := testutils.SetupUserData(db, "alice@example.com", "oldpassword") // Execute dat := url.Values{} @@ -861,10 +838,9 @@ func TestUpdatePassword(t *testing.T) { // Test assert.StatusCodeEquals(t, res, http.StatusFound, "Status code mismsatch") - var account database.Account - testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account") + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&user), "finding account") - passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte("newpassword")) + passwordErr := bcrypt.CompareHashAndPassword([]byte(user.Password.String), []byte("newpassword")) assert.Equal(t, passwordErr, nil, "Password mismatch") }) @@ -877,8 +853,7 @@ func TestUpdatePassword(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData(db) - acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword") + u := testutils.SetupUserData(db, "alice@example.com", "oldpassword") // Execute dat := url.Values{} @@ -892,9 +867,9 @@ func TestUpdatePassword(t *testing.T) { // Test assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "Status code mismsatch") - var account database.Account - testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account") - assert.Equal(t, acc.Password.String, account.Password.String, "password should not have been updated") + var user database.User + testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding account") + assert.Equal(t, u.Password.String, user.Password.String, "password should not have been updated") }) t.Run("password too short", func(t *testing.T) { @@ -907,8 +882,7 @@ func TestUpdatePassword(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData(db) - acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword") + u := testutils.SetupUserData(db, "alice@example.com", "oldpassword") // Execute dat := url.Values{} @@ -922,9 +896,9 @@ func TestUpdatePassword(t *testing.T) { // Test assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismsatch") - var account database.Account - testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account") - assert.Equal(t, acc.Password.String, account.Password.String, "password should not have been updated") + var user database.User + testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding account") + assert.Equal(t, u.Password.String, user.Password.String, "password should not have been updated") }) t.Run("password confirmation mismatch", func(t *testing.T) { @@ -937,8 +911,7 @@ func TestUpdatePassword(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData(db) - acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword") + u := testutils.SetupUserData(db, "alice@example.com", "oldpassword") // Execute dat := url.Values{} @@ -952,9 +925,9 @@ func TestUpdatePassword(t *testing.T) { // Test assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismsatch") - var account database.Account - testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account") - assert.Equal(t, acc.Password.String, account.Password.String, "password should not have been updated") + var user database.User + testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding account") + assert.Equal(t, u.Password.String, user.Password.String, "password should not have been updated") }) } @@ -969,8 +942,7 @@ func TestUpdateEmail(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData(db) - testutils.SetupAccountData(db, u, "alice@example.com", "pass1234") + u := testutils.SetupUserData(db, "alice@example.com", "pass1234") // Execute dat := url.Values{} @@ -984,11 +956,10 @@ func TestUpdateEmail(t *testing.T) { assert.StatusCodeEquals(t, res, http.StatusFound, "Status code mismatch") var user database.User - var account database.Account testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user") - testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account") + testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding account") - assert.Equal(t, account.Email.String, "alice-new@example.com", "email mismatch") + assert.Equal(t, user.Email.String, "alice-new@example.com", "email mismatch") }) t.Run("password mismatch", func(t *testing.T) { @@ -1001,8 +972,7 @@ func TestUpdateEmail(t *testing.T) { server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData(db) - testutils.SetupAccountData(db, u, "alice@example.com", "pass1234") + u := testutils.SetupUserData(db, "alice@example.com", "pass1234") // Execute dat := url.Values{} @@ -1016,11 +986,9 @@ func TestUpdateEmail(t *testing.T) { assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "Status code mismsatch") var user database.User - var account database.Account testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user") - testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account") - assert.Equal(t, account.Email.String, "alice@example.com", "email mismatch") + assert.Equal(t, user.Email.String, "alice@example.com", "email mismatch") }) } diff --git a/pkg/server/database/database.go b/pkg/server/database/database.go index eaab7c50..f73d45af 100644 --- a/pkg/server/database/database.go +++ b/pkg/server/database/database.go @@ -36,7 +36,6 @@ var ( func InitSchema(db *gorm.DB) { if err := db.AutoMigrate( &User{}, - &Account{}, &Book{}, &Note{}, &Token{}, diff --git a/pkg/server/database/models.go b/pkg/server/database/models.go index 99e41c96..576dee7f 100644 --- a/pkg/server/database/models.go +++ b/pkg/server/database/models.go @@ -61,20 +61,13 @@ type Note struct { // User is a model for a user type User struct { Model - UUID string `json:"uuid" gorm:"type:text;index"` - Account Account `gorm:"foreignKey:UserID"` + UUID string `json:"uuid" gorm:"type:text;index"` + Email NullString `gorm:"index"` + Password NullString `json:"-"` LastLoginAt *time.Time `json:"-"` MaxUSN int `json:"-" gorm:"default:0"` } -// Account is a model for an account -type Account struct { - Model - UserID int `gorm:"index"` - Email NullString - Password NullString -} - // Token is a model for a token type Token struct { Model diff --git a/pkg/server/middleware/auth.go b/pkg/server/middleware/auth.go index f74d1efa..daf92dc3 100644 --- a/pkg/server/middleware/auth.go +++ b/pkg/server/middleware/auth.go @@ -67,8 +67,6 @@ type AuthParams struct { // Auth is an authentication middleware func Auth(db *gorm.DB, next http.HandlerFunc, p *AuthParams) http.HandlerFunc { - next = WithAccount(db, next) - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user, ok, err := AuthWithSession(db, r) if !ok { @@ -93,27 +91,6 @@ func Auth(db *gorm.DB, next http.HandlerFunc, p *AuthParams) http.HandlerFunc { ctx := context.WithUser(r.Context(), &user) next.ServeHTTP(w, r.WithContext(ctx)) }) - -} - -func WithAccount(db *gorm.DB, next http.HandlerFunc) http.HandlerFunc { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user := context.User(r.Context()) - - var account database.Account - err := db.Where("user_id = ?", user.ID).First(&account).Error - if errors.Is(err, gorm.ErrRecordNotFound) { - DoError(w, "account not found", err, http.StatusForbidden) - return - } else if err != nil { - DoError(w, "finding account", err, http.StatusInternalServerError) - return - } - - ctx := context.WithAccount(r.Context(), &account) - - next.ServeHTTP(w, r.WithContext(ctx)) - }) } // TokenAuth is an authentication middleware with token diff --git a/pkg/server/middleware/auth_test.go b/pkg/server/middleware/auth_test.go index c0d94096..95c935a2 100644 --- a/pkg/server/middleware/auth_test.go +++ b/pkg/server/middleware/auth_test.go @@ -47,7 +47,7 @@ func TestGuestOnly(t *testing.T) { }) t.Run("logged in", func(t *testing.T) { - user := testutils.SetupUserData(db) + user := testutils.SetupUserData(db, "user@test.com", "password123") req := testutils.MakeReq(server.URL, "GET", "/", "") res := testutils.HTTPAuthDo(t, db, req, user) @@ -67,8 +67,7 @@ func TestGuestOnly(t *testing.T) { func TestAuth(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + user := testutils.SetupUserData(db, "alice@test.com", "pass1234") session := database.Session{ Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=", @@ -175,7 +174,7 @@ func TestAuth(t *testing.T) { func TestTokenAuth(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) + user := testutils.SetupUserData(db, "user@test.com", "password123") tok := database.Token{ UserID: user.ID, Type: database.TokenTypeResetPassword, @@ -241,9 +240,8 @@ func TestWithAccount(t *testing.T) { w.WriteHeader(http.StatusOK) } - t.Run("user with account", func(t *testing.T) { - user := testutils.SetupUserData(db) - testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + t.Run("authenticated user", func(t *testing.T) { + user := testutils.SetupUserData(db, "alice@test.com", "pass1234") server := httptest.NewServer(Auth(db, handler, nil)) defer server.Close() @@ -253,17 +251,4 @@ func TestWithAccount(t *testing.T) { assert.Equal(t, res.StatusCode, http.StatusOK, "status code mismatch") }) - - t.Run("user without account", func(t *testing.T) { - user := testutils.SetupUserData(db) - // Note: not creating account for this user - - server := httptest.NewServer(Auth(db, handler, nil)) - defer server.Close() - - req := testutils.MakeReq(server.URL, "GET", "/", "") - res := testutils.HTTPAuthDo(t, db, req, user) - - assert.Equal(t, res.StatusCode, http.StatusForbidden, "status code mismatch") - }) } diff --git a/pkg/server/operations/notes_test.go b/pkg/server/operations/notes_test.go index f003f7cb..a2ccb023 100644 --- a/pkg/server/operations/notes_test.go +++ b/pkg/server/operations/notes_test.go @@ -30,8 +30,8 @@ import ( func TestGetNote(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) - anotherUser := testutils.SetupUserData(db) + user := testutils.SetupUserData(db, "user@test.com", "password123") + anotherUser := testutils.SetupUserData(db, "another@test.com", "password123") b1 := database.Book{ UUID: testutils.MustUUID(t), @@ -98,7 +98,7 @@ func TestGetNote(t *testing.T) { func TestGetNote_nonexistent(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) + user := testutils.SetupUserData(db, "user@test.com", "password123") b1 := database.Book{ UUID: testutils.MustUUID(t), diff --git a/pkg/server/permissions/permissions_test.go b/pkg/server/permissions/permissions_test.go index ed439d28..13f13f97 100644 --- a/pkg/server/permissions/permissions_test.go +++ b/pkg/server/permissions/permissions_test.go @@ -29,8 +29,8 @@ import ( func TestViewNote(t *testing.T) { db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData(db) - anotherUser := testutils.SetupUserData(db) + user := testutils.SetupUserData(db, "user@test.com", "password123") + anotherUser := testutils.SetupUserData(db, "another@test.com", "password123") b1 := database.Book{ UUID: testutils.MustUUID(t), diff --git a/pkg/server/session/session.go b/pkg/server/session/session.go index a9494c58..d7c3d0d8 100644 --- a/pkg/server/session/session.go +++ b/pkg/server/session/session.go @@ -29,9 +29,9 @@ type Session struct { } // New returns a new session for the given user -func New(user database.User, account database.Account) Session { +func New(user database.User) Session { return Session{ UUID: user.UUID, - Email: account.Email.String, + Email: user.Email.String, } } diff --git a/pkg/server/session/session_test.go b/pkg/server/session/session_test.go index dec674b7..dddfa18b 100644 --- a/pkg/server/session/session_test.go +++ b/pkg/server/session/session_test.go @@ -27,33 +27,34 @@ import ( ) func TestNew(t *testing.T) { - u1 := database.User{UUID: "0f5f0054-d23f-4be1-b5fb-57673109e9cb"} - a1 := database.Account{Email: database.ToNullString("alice@example.com")} + u1 := database.User{ + UUID: "0f5f0054-d23f-4be1-b5fb-57673109e9cb", + Email: database.ToNullString("alice@example.com"), + } - u2 := database.User{UUID: "718a1041-bbe6-496e-bbe4-ea7e572c295e"} - a2 := database.Account{Email: database.ToNullString("bob@example.com")} + u2 := database.User{ + UUID: "718a1041-bbe6-496e-bbe4-ea7e572c295e", + Email: database.ToNullString("bob@example.com"), + } testCases := []struct { - user database.User - account database.Account + user database.User }{ { - user: u1, - account: a1, + user: u1, }, { - user: u2, - account: a2, + user: u2, }, } for idx, tc := range testCases { t.Run(fmt.Sprintf("user %d", idx), func(t *testing.T) { // Execute - got := New(tc.user, tc.account) + got := New(tc.user) expected := Session{ UUID: tc.user.UUID, - Email: tc.account.Email.String, + Email: tc.user.Email.String, } assert.DeepEqual(t, got, expected, "result mismatch") diff --git a/pkg/server/testutils/main.go b/pkg/server/testutils/main.go index 08bc07de..45c547d0 100644 --- a/pkg/server/testutils/main.go +++ b/pkg/server/testutils/main.go @@ -82,9 +82,6 @@ func ClearData(db *gorm.DB) { if err := db.Where("1 = 1").Delete(&database.Session{}).Error; err != nil { panic(errors.Wrap(err, "Failed to clear sessions")) } - if err := db.Where("1 = 1").Delete(&database.Account{}).Error; err != nil { - panic(errors.Wrap(err, "Failed to clear accounts")) - } if err := db.Where("1 = 1").Delete(&database.User{}).Error; err != nil { panic(errors.Wrap(err, "Failed to clear users")) } @@ -99,15 +96,22 @@ func MustUUID(t *testing.T) string { return uuid } -// SetupUserData creates and returns a new user for testing purposes -func SetupUserData(db *gorm.DB) database.User { +// SetupUserData creates and returns a new user with email and password for testing purposes +func SetupUserData(db *gorm.DB, email, password string) database.User { uuid, err := helpers.GenUUID() if err != nil { panic(errors.Wrap(err, "Failed to generate UUID")) } + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + panic(errors.Wrap(err, "Failed to hash password")) + } + user := database.User{ - UUID: uuid, + UUID: uuid, + Email: database.ToNullString(email), + Password: database.ToNullString(string(hashedPassword)), } if err := db.Save(&user).Error; err != nil { @@ -117,28 +121,6 @@ func SetupUserData(db *gorm.DB) database.User { return user } -// SetupAccountData creates and returns a new account for the user -func SetupAccountData(db *gorm.DB, user database.User, email, password string) database.Account { - account := database.Account{ - UserID: user.ID, - } - if email != "" { - account.Email = database.ToNullString(email) - } - - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - panic(errors.Wrap(err, "Failed to hash password")) - } - account.Password = database.ToNullString(string(hashedPassword)) - - if err := db.Save(&account).Error; err != nil { - panic(errors.Wrap(err, "Failed to prepare account")) - } - - return account -} - // SetupSession creates and returns a new user session func SetupSession(db *gorm.DB, user database.User) database.Session { session := database.Session{ diff --git a/pkg/server/token/token_test.go b/pkg/server/token/token_test.go index 426ff3d1..c7469d4f 100644 --- a/pkg/server/token/token_test.go +++ b/pkg/server/token/token_test.go @@ -42,7 +42,7 @@ func TestCreate(t *testing.T) { db := testutils.InitMemoryDB(t) // Set up - u := testutils.SetupUserData(db) + u := testutils.SetupUserData(db, "user@test.com", "password123") // Execute tok, err := Create(db, u.ID, tc.kind) diff --git a/pkg/server/views/data.go b/pkg/server/views/data.go index 451cade8..98444364 100644 --- a/pkg/server/views/data.go +++ b/pkg/server/views/data.go @@ -50,9 +50,8 @@ type Alert struct { type Data struct { Alert *Alert // CSRF template.HTML - User *database.User - Account *database.Account - Yield map[string]interface{} + User *database.User + Yield map[string]interface{} } func getErrMessage(err error) string { diff --git a/pkg/server/views/templates/layouts/navbar.gohtml b/pkg/server/views/templates/layouts/navbar.gohtml index ea7c026f..b39a1e4e 100644 --- a/pkg/server/views/templates/layouts/navbar.gohtml +++ b/pkg/server/views/templates/layouts/navbar.gohtml @@ -41,7 +41,7 @@