From 505d900d31839cd732b9948ba6ce30efc84d4eb2 Mon Sep 17 00:00:00 2001 From: Sung Date: Sun, 19 Oct 2025 09:34:07 -0700 Subject: [PATCH] Manage users with server CLI --- .gitignore | 2 +- pkg/assert/prompt.go | 87 ++++++++++ pkg/cli/testutils/main.go | 85 +-------- pkg/cli/ui/terminal.go | 21 +-- pkg/e2e/server_test.go | 140 +++++++++++++++ pkg/prompt/prompt.go | 56 ++++++ pkg/prompt/prompt_test.go | 148 ++++++++++++++++ pkg/server/app/errors.go | 3 + pkg/server/app/testutils.go | 2 +- pkg/server/app/users.go | 99 ++++++++++- pkg/server/app/users_test.go | 299 ++++++++++++++++++++++++++++++++ pkg/server/cmd/helpers.go | 111 ++++++++++++ pkg/server/cmd/root.go | 60 +++++++ pkg/server/cmd/start.go | 91 ++++++++++ pkg/server/cmd/user.go | 190 ++++++++++++++++++++ pkg/server/cmd/user_test.go | 114 ++++++++++++ pkg/server/cmd/version.go | 29 ++++ pkg/server/controllers/users.go | 37 +--- pkg/server/log/log_test.go | 38 ++++ pkg/server/main.go | 151 +--------------- 20 files changed, 1484 insertions(+), 279 deletions(-) create mode 100644 pkg/assert/prompt.go create mode 100644 pkg/prompt/prompt.go create mode 100644 pkg/prompt/prompt_test.go create mode 100644 pkg/server/cmd/helpers.go create mode 100644 pkg/server/cmd/root.go create mode 100644 pkg/server/cmd/start.go create mode 100644 pkg/server/cmd/user.go create mode 100644 pkg/server/cmd/user_test.go create mode 100644 pkg/server/cmd/version.go create mode 100644 pkg/server/log/log_test.go diff --git a/.gitignore b/.gitignore index 57d82ddc..2847e75d 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,4 @@ node_modules /test tmp *.db -server +/server diff --git a/pkg/assert/prompt.go b/pkg/assert/prompt.go new file mode 100644 index 00000000..d4ec5d25 --- /dev/null +++ b/pkg/assert/prompt.go @@ -0,0 +1,87 @@ +/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors + * + * This file is part of Dnote. + * + * Dnote is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Dnote is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with Dnote. If not, see . + */ + +package assert + +import ( + "bufio" + "io" + "strings" + "time" + + "github.com/pkg/errors" +) + +// WaitForPrompt waits for an expected prompt to appear in stdout with a timeout. +// Returns an error if the prompt is not found within the timeout period. +// Handles prompts with or without newlines by reading character by character. +func WaitForPrompt(stdout io.Reader, expectedPrompt string, timeout time.Duration) error { + type result struct { + found bool + err error + } + resultCh := make(chan result, 1) + + go func() { + reader := bufio.NewReader(stdout) + var buffer strings.Builder + found := false + + for { + b, err := reader.ReadByte() + if err != nil { + resultCh <- result{found: found, err: err} + return + } + + buffer.WriteByte(b) + if strings.Contains(buffer.String(), expectedPrompt) { + found = true + break + } + } + + resultCh <- result{found: found, err: nil} + }() + + select { + case res := <-resultCh: + if res.err != nil && res.err != io.EOF { + return errors.Wrap(res.err, "reading stdout") + } + if !res.found { + return errors.Errorf("expected prompt '%s' not found in stdout", expectedPrompt) + } + return nil + case <-time.After(timeout): + return errors.Errorf("timeout waiting for prompt '%s'", expectedPrompt) + } +} + +// RespondToPrompt is a helper that waits for a prompt and sends a response. +func RespondToPrompt(stdout io.Reader, stdin io.WriteCloser, expectedPrompt, response string, timeout time.Duration) error { + if err := WaitForPrompt(stdout, expectedPrompt, timeout); err != nil { + return err + } + + if _, err := io.WriteString(stdin, response); err != nil { + return errors.Wrap(err, "writing response to stdin") + } + + return nil +} diff --git a/pkg/cli/testutils/main.go b/pkg/cli/testutils/main.go index ee853b48..6c6caa64 100644 --- a/pkg/cli/testutils/main.go +++ b/pkg/cli/testutils/main.go @@ -20,7 +20,6 @@ package testutils import ( - "bufio" "bytes" "encoding/json" "io" @@ -31,6 +30,7 @@ import ( "testing" "time" + "github.com/dnote/dnote/pkg/assert" "github.com/dnote/dnote/pkg/cli/consts" "github.com/dnote/dnote/pkg/cli/context" "github.com/dnote/dnote/pkg/cli/database" @@ -223,103 +223,32 @@ func MustWaitDnoteCmd(t *testing.T, opts RunDnoteCmdOptions, runFunc func(io.Rea return output } -// waitForPrompt waits for an expected prompt to appear in stdout with a timeout. -// Returns an error if the prompt is not found within the timeout period. -// Handles prompts with or without newlines by reading character by character. -func waitForPrompt(stdout io.Reader, expectedPrompt string, timeout time.Duration) error { - type result struct { - found bool - err error - } - resultCh := make(chan result, 1) - - go func() { - reader := bufio.NewReader(stdout) - var buffer strings.Builder - found := false - - for { - b, err := reader.ReadByte() - if err != nil { - resultCh <- result{found: found, err: err} - return - } - - buffer.WriteByte(b) - if strings.Contains(buffer.String(), expectedPrompt) { - found = true - break - } - } - - resultCh <- result{found: found, err: nil} - }() - - select { - case res := <-resultCh: - if res.err != nil && res.err != io.EOF { - return errors.Wrap(res.err, "reading stdout") - } - if !res.found { - return errors.Errorf("expected prompt '%s' not found in stdout", expectedPrompt) - } - return nil - case <-time.After(timeout): - return errors.Errorf("timeout waiting for prompt '%s'", expectedPrompt) - } -} - // MustWaitForPrompt waits for an expected prompt with a default timeout. // Fails the test if the prompt is not found or an error occurs. func MustWaitForPrompt(t *testing.T, stdout io.Reader, expectedPrompt string) { - if err := waitForPrompt(stdout, expectedPrompt, promptTimeout); err != nil { + if err := assert.WaitForPrompt(stdout, expectedPrompt, promptTimeout); err != nil { t.Fatal(err) } } -// userRespondToPrompt is a helper that waits for a prompt and sends a response. -func userRespondToPrompt(stdout io.Reader, stdin io.WriteCloser, expectedPrompt, response, action string) error { - if err := waitForPrompt(stdout, expectedPrompt, promptTimeout); err != nil { - return err - } - - if _, err := io.WriteString(stdin, response); err != nil { - return errors.Wrapf(err, "indicating %s in stdin", action) - } - - return nil -} - -// userConfirmOutput simulates confirmation from the user by writing to stdin. -// It waits for the expected prompt with a timeout to prevent deadlocks. -func userConfirmOutput(stdout io.Reader, stdin io.WriteCloser, expectedPrompt string) error { - return userRespondToPrompt(stdout, stdin, expectedPrompt, "y\n", "confirmation") -} - -// userCancelOutput simulates cancellation from the user by writing to stdin. -// It waits for the expected prompt with a timeout to prevent deadlocks. -func userCancelOutput(stdout io.Reader, stdin io.WriteCloser, expectedPrompt string) error { - return userRespondToPrompt(stdout, stdin, expectedPrompt, "n\n", "cancellation") -} - // ConfirmRemoveNote waits for prompt for removing a note and confirms. func ConfirmRemoveNote(stdout io.Reader, stdin io.WriteCloser) error { - return userConfirmOutput(stdout, stdin, PromptRemoveNote) + return assert.RespondToPrompt(stdout, stdin, PromptRemoveNote, "y\n", promptTimeout) } // ConfirmRemoveBook waits for prompt for deleting a book confirms. func ConfirmRemoveBook(stdout io.Reader, stdin io.WriteCloser) error { - return userConfirmOutput(stdout, stdin, PromptDeleteBook) + return assert.RespondToPrompt(stdout, stdin, PromptDeleteBook, "y\n", promptTimeout) } // UserConfirmEmptyServerSync waits for an empty server prompt and confirms. func UserConfirmEmptyServerSync(stdout io.Reader, stdin io.WriteCloser) error { - return userConfirmOutput(stdout, stdin, PromptEmptyServer) + return assert.RespondToPrompt(stdout, stdin, PromptEmptyServer, "y\n", promptTimeout) } -// UserCancelEmptyServerSync waits for an empty server prompt and confirms. +// UserCancelEmptyServerSync waits for an empty server prompt and cancels. func UserCancelEmptyServerSync(stdout io.Reader, stdin io.WriteCloser) error { - return userCancelOutput(stdout, stdin, PromptEmptyServer) + return assert.RespondToPrompt(stdout, stdin, PromptEmptyServer, "n\n", promptTimeout) } // UserContent simulates content from the user by writing to stdin. diff --git a/pkg/cli/ui/terminal.go b/pkg/cli/ui/terminal.go index ab52873d..899060c7 100644 --- a/pkg/cli/ui/terminal.go +++ b/pkg/cli/ui/terminal.go @@ -26,6 +26,7 @@ import ( "syscall" "github.com/dnote/dnote/pkg/cli/log" + "github.com/dnote/dnote/pkg/prompt" "github.com/pkg/errors" "golang.org/x/crypto/ssh/terminal" ) @@ -73,26 +74,16 @@ func PromptPassword(message string, dest *string) error { // Confirm prompts for user input to confirm a choice func Confirm(question string, optimistic bool) (bool, error) { - var choices string - if optimistic { - choices = "(Y/n)" - } else { - choices = "(y/N)" - } + message := prompt.FormatQuestion(question, optimistic) - message := fmt.Sprintf("%s %s", question, choices) + // Use log.Askf for colored prompt in CLI + log.Askf(message, false) - var input string - if err := PromptInput(message, &input); err != nil { + confirmed, err := prompt.ReadYesNo(os.Stdin, optimistic) + if err != nil { return false, errors.Wrap(err, "Failed to get user input") } - confirmed := input == "y" - - if optimistic { - confirmed = confirmed || input == "" - } - return confirmed, nil } diff --git a/pkg/e2e/server_test.go b/pkg/e2e/server_test.go index c37b7320..cfe6a711 100644 --- a/pkg/e2e/server_test.go +++ b/pkg/e2e/server_test.go @@ -19,6 +19,7 @@ package main import ( + "bytes" "fmt" "net/http" "os" @@ -181,3 +182,142 @@ func TestServerUnknownCommand(t *testing.T) { assert.Equal(t, strings.Contains(outputStr, "Unknown command"), true, "output should contain unknown command message") assert.Equal(t, strings.Contains(outputStr, "Dnote server - a simple command line notebook"), true, "output should show help") } + +func TestServerUserCreate(t *testing.T) { + tmpDB := t.TempDir() + "/test.db" + + cmd := exec.Command(testServerBinary, "user", "create", + "--dbPath", tmpDB, + "--email", "test@example.com", + "--password", "password123") + output, err := cmd.CombinedOutput() + + if err != nil { + t.Fatalf("user create failed: %v\nOutput: %s", err, output) + } + + outputStr := string(output) + assert.Equal(t, strings.Contains(outputStr, "User created successfully"), true, "output should show success message") + assert.Equal(t, strings.Contains(outputStr, "test@example.com"), true, "output should show email") + + // Verify user exists in database + db, err := gorm.Open(sqlite.Open(tmpDB), &gorm.Config{}) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + defer func() { + sqlDB, _ := db.DB() + sqlDB.Close() + }() + + var count int64 + db.Table("users").Count(&count) + assert.Equal(t, count, int64(1), "should have created 1 user") +} + +func TestServerUserCreateShortPassword(t *testing.T) { + tmpDB := t.TempDir() + "/test.db" + + cmd := exec.Command(testServerBinary, "user", "create", + "--dbPath", tmpDB, + "--email", "test@example.com", + "--password", "short") + output, err := cmd.CombinedOutput() + + // Should fail with short password + if err == nil { + t.Fatal("expected command to fail with short password") + } + + outputStr := string(output) + assert.Equal(t, strings.Contains(outputStr, "password should be longer than 8 characters"), true, "output should show password error") +} + +func TestServerUserResetPassword(t *testing.T) { + tmpDB := t.TempDir() + "/test.db" + + // Create user first + createCmd := exec.Command(testServerBinary, "user", "create", + "--dbPath", tmpDB, + "--email", "test@example.com", + "--password", "oldpassword123") + if output, err := createCmd.CombinedOutput(); err != nil { + t.Fatalf("failed to create user: %v\nOutput: %s", err, output) + } + + // Reset password + resetCmd := exec.Command(testServerBinary, "user", "reset-password", + "--dbPath", tmpDB, + "--email", "test@example.com", + "--password", "newpassword123") + output, err := resetCmd.CombinedOutput() + + if err != nil { + t.Fatalf("reset-password failed: %v\nOutput: %s", err, output) + } + + outputStr := string(output) + assert.Equal(t, strings.Contains(outputStr, "Password reset successfully"), true, "output should show success message") +} + +func TestServerUserRemove(t *testing.T) { + tmpDB := t.TempDir() + "/test.db" + + // Create user first + createCmd := exec.Command(testServerBinary, "user", "create", + "--dbPath", tmpDB, + "--email", "test@example.com", + "--password", "password123") + if output, err := createCmd.CombinedOutput(); err != nil { + t.Fatalf("failed to create user: %v\nOutput: %s", err, output) + } + + // Remove user with confirmation + removeCmd := exec.Command(testServerBinary, "user", "remove", + "--dbPath", tmpDB, + "--email", "test@example.com") + + // Pipe "y" to stdin to confirm removal + stdin, err := removeCmd.StdinPipe() + if err != nil { + t.Fatalf("failed to create stdin pipe: %v", err) + } + + // Capture output + stdout, err := removeCmd.StdoutPipe() + if err != nil { + t.Fatalf("failed to create stdout pipe: %v", err) + } + + var stderr bytes.Buffer + removeCmd.Stderr = &stderr + + // Start command + if err := removeCmd.Start(); err != nil { + t.Fatalf("failed to start remove command: %v", err) + } + + // Wait for prompt and send "y" to confirm + if err := assert.RespondToPrompt(stdout, stdin, "Remove user test@example.com?", "y\n", 10*time.Second); err != nil { + t.Fatalf("failed to confirm removal: %v", err) + } + + // Wait for command to finish + if err := removeCmd.Wait(); err != nil { + t.Fatalf("user remove failed: %v\nStderr: %s", err, stderr.String()) + } + + // Verify user was removed + db, err := gorm.Open(sqlite.Open(tmpDB), &gorm.Config{}) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + defer func() { + sqlDB, _ := db.DB() + sqlDB.Close() + }() + + var count int64 + db.Table("users").Count(&count) + assert.Equal(t, count, int64(0), "should have 0 users after removal") +} diff --git a/pkg/prompt/prompt.go b/pkg/prompt/prompt.go new file mode 100644 index 00000000..1d413a27 --- /dev/null +++ b/pkg/prompt/prompt.go @@ -0,0 +1,56 @@ +/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors + * + * This file is part of Dnote. + * + * Dnote is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Dnote is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Dnote. If not, see . + */ + +// Package prompt provides utilities for interactive yes/no prompts +package prompt + +import ( + "bufio" + "fmt" + "io" + "strings" +) + +// FormatQuestion formats a yes/no question with the appropriate choice indicator +func FormatQuestion(question string, optimistic bool) string { + choices := "(y/N)" + if optimistic { + choices = "(Y/n)" + } + return fmt.Sprintf("%s %s", question, choices) +} + +// ReadYesNo reads and parses a yes/no response from the given reader. +// Returns true if confirmed, respecting optimistic mode. +// In optimistic mode, empty input is treated as confirmation. +func ReadYesNo(r io.Reader, optimistic bool) (bool, error) { + reader := bufio.NewReader(r) + input, err := reader.ReadString('\n') + if err != nil { + return false, err + } + + input = strings.ToLower(strings.TrimSpace(input)) + confirmed := input == "y" + + if optimistic { + confirmed = confirmed || input == "" + } + + return confirmed, nil +} diff --git a/pkg/prompt/prompt_test.go b/pkg/prompt/prompt_test.go new file mode 100644 index 00000000..f0df9480 --- /dev/null +++ b/pkg/prompt/prompt_test.go @@ -0,0 +1,148 @@ +/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors + * + * This file is part of Dnote. + * + * Dnote is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Dnote is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Dnote. If not, see . + */ + +package prompt + +import ( + "strings" + "testing" + + "github.com/dnote/dnote/pkg/assert" +) + +func TestFormatQuestion(t *testing.T) { + testCases := []struct { + question string + optimistic bool + expected string + }{ + { + question: "Are you sure?", + optimistic: false, + expected: "Are you sure? (y/N)", + }, + { + question: "Continue?", + optimistic: true, + expected: "Continue? (Y/n)", + }, + } + + for _, tc := range testCases { + t.Run(tc.question, func(t *testing.T) { + result := FormatQuestion(tc.question, tc.optimistic) + assert.Equal(t, result, tc.expected, "formatted question mismatch") + }) + } +} + +func TestReadYesNo(t *testing.T) { + testCases := []struct { + name string + input string + optimistic bool + expected bool + }{ + { + name: "pessimistic with y", + input: "y\n", + optimistic: false, + expected: true, + }, + { + name: "pessimistic with Y (uppercase)", + input: "Y\n", + optimistic: false, + expected: true, + }, + { + name: "pessimistic with n", + input: "n\n", + optimistic: false, + expected: false, + }, + { + name: "pessimistic with empty", + input: "\n", + optimistic: false, + expected: false, + }, + { + name: "pessimistic with whitespace", + input: " \n", + optimistic: false, + expected: false, + }, + { + name: "optimistic with y", + input: "y\n", + optimistic: true, + expected: true, + }, + { + name: "optimistic with n", + input: "n\n", + optimistic: true, + expected: false, + }, + { + name: "optimistic with empty", + input: "\n", + optimistic: true, + expected: true, + }, + { + name: "optimistic with whitespace", + input: " \n", + optimistic: true, + expected: true, + }, + { + name: "invalid input defaults to no", + input: "maybe\n", + optimistic: false, + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create a reader with test input + reader := strings.NewReader(tc.input) + + // Test ReadYesNo + result, err := ReadYesNo(reader, tc.optimistic) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + assert.Equal(t, result, tc.expected, "ReadYesNo result mismatch") + }) + } +} + +func TestReadYesNo_Error(t *testing.T) { + // Test error case with EOF (empty reader) + reader := strings.NewReader("") + + _, err := ReadYesNo(reader, false) + if err == nil { + t.Fatal("expected error when reading from empty reader") + } +} + diff --git a/pkg/server/app/errors.go b/pkg/server/app/errors.go index 895fc40f..48f193b9 100644 --- a/pkg/server/app/errors.go +++ b/pkg/server/app/errors.go @@ -79,4 +79,7 @@ var ( ErrInvalidPassword appError = "Invalid currnet password." // ErrEmailTooLong is an error for email length exceeding the limit ErrEmailTooLong appError = "Email is too long." + + // ErrUserHasExistingResources is an error for attempting to remove a user with existing notes or books + ErrUserHasExistingResources appError = "cannot remove user with existing notes or books" ) diff --git a/pkg/server/app/testutils.go b/pkg/server/app/testutils.go index 06664c5f..45dbd5bb 100644 --- a/pkg/server/app/testutils.go +++ b/pkg/server/app/testutils.go @@ -36,7 +36,7 @@ func NewTest() App { WebURL: "http://127.0.0.0.1", Port: "3000", DisableRegistration: false, - DBPath: ":memory:", + DBPath: "", AssetBaseURL: "", } } diff --git a/pkg/server/app/users.go b/pkg/server/app/users.go index 5c9d7f1f..b3993d55 100644 --- a/pkg/server/app/users.go +++ b/pkg/server/app/users.go @@ -29,6 +29,15 @@ import ( "gorm.io/gorm" ) +// validatePassword validates a password +func validatePassword(password string) error { + if len(password) < 8 { + return ErrPasswordTooShort + } + + return nil +} + // TouchLastLoginAt updates the last login timestamp func (a *App) TouchLastLoginAt(user database.User, tx *gorm.DB) error { t := a.Clock.Now() @@ -45,8 +54,8 @@ func (a *App) CreateUser(email, password string, passwordConfirmation string) (d return database.User{}, ErrEmailRequired } - if len(password) < 8 { - return database.User{}, ErrPasswordTooShort + if err := validatePassword(password); err != nil { + return database.User{}, err } if password != passwordConfirmation { @@ -102,8 +111,8 @@ func (a *App) CreateUser(email, password string, passwordConfirmation string) (d return user, nil } -// Authenticate authenticates a user -func (a *App) Authenticate(email, password string) (*database.User, error) { +// GetAccountByEmail finds an account by email +func (a *App) GetAccountByEmail(email string) (*database.Account, error) { var account database.Account err := a.DB.Where("email = ?", email).First(&account).Error if errors.Is(err, gorm.ErrRecordNotFound) { @@ -112,6 +121,16 @@ func (a *App) Authenticate(email, password string) (*database.User, error) { return nil, err } + return &account, nil +} + +// Authenticate authenticates a user +func (a *App) Authenticate(email, password string) (*database.User, error) { + account, err := a.GetAccountByEmail(email) + if err != nil { + return nil, err + } + err = bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte(password)) if err != nil { return nil, ErrLoginInvalid @@ -126,6 +145,78 @@ func (a *App) Authenticate(email, password string) (*database.User, error) { return &user, nil } +// UpdateAccountPassword updates an account's password with validation +func UpdateAccountPassword(db *gorm.DB, account *database.Account, newPassword string) error { + // Validate password + if err := validatePassword(newPassword); err != nil { + return err + } + + // Hash the password + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost) + if err != nil { + return pkgErrors.Wrap(err, "hashing password") + } + + // Update the password + if err := db.Model(&account).Update("password", string(hashedPassword)).Error; err != nil { + return pkgErrors.Wrap(err, "updating password") + } + + return nil +} + +// RemoveUser removes a user and their account from the system +// Returns an error if the user has any notes or books +func (a *App) RemoveUser(email string) error { + // Find the account and user + account, err := a.GetAccountByEmail(email) + if err != nil { + return err + } + + // Check if user has any notes + var noteCount int64 + if err := a.DB.Model(&database.Note{}).Where("user_id = ? AND deleted = ?", account.UserID, false).Count(¬eCount).Error; err != nil { + return pkgErrors.Wrap(err, "counting notes") + } + if noteCount > 0 { + return ErrUserHasExistingResources + } + + // Check if user has any books + var bookCount int64 + if err := a.DB.Model(&database.Book{}).Where("user_id = ? AND deleted = ?", account.UserID, false).Count(&bookCount).Error; err != nil { + return pkgErrors.Wrap(err, "counting books") + } + if bookCount > 0 { + return ErrUserHasExistingResources + } + + // Delete account and user in a transaction + tx := a.DB.Begin() + + if err := tx.Delete(&account).Error; err != nil { + tx.Rollback() + return pkgErrors.Wrap(err, "deleting account") + } + + var user database.User + if err := tx.Where("id = ?", account.UserID).First(&user).Error; err != nil { + tx.Rollback() + return pkgErrors.Wrap(err, "finding user") + } + + if err := tx.Delete(&user).Error; err != nil { + tx.Rollback() + return pkgErrors.Wrap(err, "deleting user") + } + + tx.Commit() + + return nil +} + // SignIn signs in a user func (a *App) SignIn(user *database.User) (*database.Session, error) { err := a.TouchLastLoginAt(*user, a.DB) diff --git a/pkg/server/app/users_test.go b/pkg/server/app/users_test.go index 45c43514..a4c3a60d 100644 --- a/pkg/server/app/users_test.go +++ b/pkg/server/app/users_test.go @@ -28,6 +28,42 @@ import ( "golang.org/x/crypto/bcrypt" ) +func TestValidatePassword(t *testing.T) { + testCases := []struct { + name string + password string + wantErr error + }{ + { + name: "valid password", + password: "password123", + wantErr: nil, + }, + { + name: "valid password exactly 8 chars", + password: "12345678", + wantErr: nil, + }, + { + name: "password too short", + password: "1234567", + wantErr: ErrPasswordTooShort, + }, + { + name: "empty password", + password: "", + wantErr: ErrPasswordTooShort, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validatePassword(tc.password) + assert.Equal(t, err, tc.wantErr, "error mismatch") + }) + } +} + func TestCreateUser_ProValue(t *testing.T) { db := testutils.InitMemoryDB(t) @@ -46,6 +82,36 @@ func TestCreateUser_ProValue(t *testing.T) { } +func TestGetAccountByEmail(t *testing.T) { + t.Run("success", func(t *testing.T) { + db := testutils.InitMemoryDB(t) + + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@example.com", "password123") + + a := NewTest() + a.DB = db + + account, err := a.GetAccountByEmail("alice@example.com") + + assert.Equal(t, err, nil, "should not error") + assert.Equal(t, account.Email.String, "alice@example.com", "email mismatch") + assert.Equal(t, account.UserID, user.ID, "user ID mismatch") + }) + + t.Run("not found", func(t *testing.T) { + db := testutils.InitMemoryDB(t) + + a := NewTest() + a.DB = db + + account, err := a.GetAccountByEmail("nonexistent@example.com") + + assert.Equal(t, err, ErrNotFound, "should return ErrNotFound") + assert.Equal(t, account, (*database.Account)(nil), "account should be nil") + }) +} + func TestCreateUser(t *testing.T) { t.Run("success", func(t *testing.T) { db := testutils.InitMemoryDB(t) @@ -92,3 +158,236 @@ func TestCreateUser(t *testing.T) { assert.Equal(t, accountCount, int64(1), "account count mismatch") }) } + +func TestUpdateAccountPassword(t *testing.T) { + t.Run("success", func(t *testing.T) { + db := testutils.InitMemoryDB(t) + + user := testutils.SetupUserData(db) + account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123") + + err := UpdateAccountPassword(db, &account, "newpassword123") + + assert.Equal(t, err, nil, "should not error") + + // Verify password was updated in database + var updatedAccount database.Account + testutils.MustExec(t, db.Where("id = ?", account.ID).First(&updatedAccount), "finding updated account") + + // Verify new password works + passwordErr := bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("newpassword123")) + assert.Equal(t, passwordErr, nil, "New password should match") + + // Verify old password no longer works + oldPasswordErr := bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("oldpassword123")) + assert.NotEqual(t, oldPasswordErr, nil, "Old password should not match") + }) + + t.Run("password too short", func(t *testing.T) { + db := testutils.InitMemoryDB(t) + + user := testutils.SetupUserData(db) + account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123") + + err := UpdateAccountPassword(db, &account, "short") + + assert.Equal(t, err, ErrPasswordTooShort, "should return ErrPasswordTooShort") + + // Verify password was NOT updated in database + var unchangedAccount database.Account + testutils.MustExec(t, db.Where("id = ?", account.ID).First(&unchangedAccount), "finding unchanged account") + + // Verify old password still works + passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedAccount.Password.String), []byte("oldpassword123")) + assert.Equal(t, passwordErr, nil, "Old password should still match") + }) + + t.Run("empty password", func(t *testing.T) { + db := testutils.InitMemoryDB(t) + + user := testutils.SetupUserData(db) + account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123") + + err := UpdateAccountPassword(db, &account, "") + + assert.Equal(t, err, ErrPasswordTooShort, "should return ErrPasswordTooShort") + + // Verify password was NOT updated in database + var unchangedAccount database.Account + testutils.MustExec(t, db.Where("id = ?", account.ID).First(&unchangedAccount), "finding unchanged account") + + // Verify old password still works + passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedAccount.Password.String), []byte("oldpassword123")) + assert.Equal(t, passwordErr, nil, "Old password should still match") + }) + + t.Run("transaction rollback", func(t *testing.T) { + db := testutils.InitMemoryDB(t) + + user := testutils.SetupUserData(db) + account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123") + + // Start a transaction and rollback to verify UpdateAccountPassword respects transactions + tx := db.Begin() + err := UpdateAccountPassword(tx, &account, "newpassword123") + assert.Equal(t, err, nil, "should not error") + tx.Rollback() + + // Verify password was NOT updated after rollback + var unchangedAccount database.Account + testutils.MustExec(t, db.Where("id = ?", account.ID).First(&unchangedAccount), "finding unchanged account") + + // Verify old password still works + passwordErr := bcrypt.CompareHashAndPassword([]byte(unchangedAccount.Password.String), []byte("oldpassword123")) + assert.Equal(t, passwordErr, nil, "Old password should still match after rollback") + }) + + t.Run("transaction commit", func(t *testing.T) { + db := testutils.InitMemoryDB(t) + + user := testutils.SetupUserData(db) + account := testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword123") + + // Start a transaction and commit to verify UpdateAccountPassword respects transactions + tx := db.Begin() + err := UpdateAccountPassword(tx, &account, "newpassword123") + assert.Equal(t, err, nil, "should not error") + tx.Commit() + + // Verify password was updated after commit + var updatedAccount database.Account + testutils.MustExec(t, db.Where("id = ?", account.ID).First(&updatedAccount), "finding updated account") + + // Verify new password works + passwordErr := bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("newpassword123")) + assert.Equal(t, passwordErr, nil, "New password should match after commit") + }) +} + +func TestRemoveUser(t *testing.T) { + t.Run("success", func(t *testing.T) { + db := testutils.InitMemoryDB(t) + + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@example.com", "password123") + + a := NewTest() + a.DB = db + + err := a.RemoveUser("alice@example.com") + + assert.Equal(t, err, nil, "should not error") + + // Verify user was deleted + var userCount int64 + testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting users") + assert.Equal(t, userCount, int64(0), "user should be deleted") + + // Verify account was deleted + var accountCount int64 + testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting accounts") + assert.Equal(t, accountCount, int64(0), "account should be deleted") + }) + + t.Run("user not found", func(t *testing.T) { + db := testutils.InitMemoryDB(t) + + a := NewTest() + a.DB = db + + err := a.RemoveUser("nonexistent@example.com") + + assert.Equal(t, err, ErrNotFound, "should return ErrNotFound") + }) + + t.Run("user has notes", func(t *testing.T) { + db := testutils.InitMemoryDB(t) + + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@example.com", "password123") + + book := database.Book{UserID: user.ID, Label: "testbook", Deleted: false} + testutils.MustExec(t, db.Save(&book), "creating book") + + note := database.Note{UserID: user.ID, BookUUID: book.UUID, Body: "test note", Deleted: false} + testutils.MustExec(t, db.Save(¬e), "creating note") + + a := NewTest() + a.DB = db + + err := a.RemoveUser("alice@example.com") + + assert.Equal(t, err, ErrUserHasExistingResources, "should return ErrUserHasExistingResources") + + // Verify user was NOT deleted + var userCount int64 + testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting users") + assert.Equal(t, userCount, int64(1), "user should not be deleted") + + // Verify account was NOT deleted + var accountCount int64 + testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting accounts") + assert.Equal(t, accountCount, int64(1), "account should not be deleted") + }) + + t.Run("user has books", func(t *testing.T) { + db := testutils.InitMemoryDB(t) + + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@example.com", "password123") + + book := database.Book{UserID: user.ID, Label: "testbook", Deleted: false} + testutils.MustExec(t, db.Save(&book), "creating book") + + a := NewTest() + a.DB = db + + err := a.RemoveUser("alice@example.com") + + assert.Equal(t, err, ErrUserHasExistingResources, "should return ErrUserHasExistingResources") + + // Verify user was NOT deleted + var userCount int64 + testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting users") + assert.Equal(t, userCount, int64(1), "user should not be deleted") + + // Verify account was NOT deleted + var accountCount int64 + testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting accounts") + assert.Equal(t, accountCount, int64(1), "account should not be deleted") + }) + + t.Run("user has deleted notes and books", func(t *testing.T) { + db := testutils.InitMemoryDB(t) + + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@example.com", "password123") + + book := database.Book{UserID: user.ID, Label: "testbook", Deleted: false} + testutils.MustExec(t, db.Save(&book), "creating book") + + note := database.Note{UserID: user.ID, BookUUID: book.UUID, Body: "test note", Deleted: false} + testutils.MustExec(t, db.Save(¬e), "creating note") + + // Soft delete the note and book + testutils.MustExec(t, db.Model(¬e).Update("deleted", true), "soft deleting note") + testutils.MustExec(t, db.Model(&book).Update("deleted", true), "soft deleting book") + + a := NewTest() + a.DB = db + + err := a.RemoveUser("alice@example.com") + + assert.Equal(t, err, nil, "should not error when user only has deleted notes and books") + + // Verify user was deleted + var userCount int64 + testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting users") + assert.Equal(t, userCount, int64(0), "user should be deleted") + + // Verify account was deleted + var accountCount int64 + testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting accounts") + assert.Equal(t, accountCount, int64(0), "account should be deleted") + }) +} diff --git a/pkg/server/cmd/helpers.go b/pkg/server/cmd/helpers.go new file mode 100644 index 00000000..a22c8721 --- /dev/null +++ b/pkg/server/cmd/helpers.go @@ -0,0 +1,111 @@ +/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors + * + * This file is part of Dnote. + * + * Dnote is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Dnote is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Dnote. If not, see . + */ + +package cmd + +import ( + "flag" + "fmt" + "os" + + "github.com/dnote/dnote/pkg/clock" + "github.com/dnote/dnote/pkg/server/app" + "github.com/dnote/dnote/pkg/server/config" + "github.com/dnote/dnote/pkg/server/database" + "github.com/dnote/dnote/pkg/server/log" + "github.com/dnote/dnote/pkg/server/mailer" + "gorm.io/gorm" +) + +func initDB(dbPath string) *gorm.DB { + db := database.Open(dbPath) + database.InitSchema(db) + database.Migrate(db) + + return db +} + +func initApp(cfg config.Config) app.App { + db := initDB(cfg.DBPath) + + emailBackend, err := mailer.NewDefaultBackend(cfg.IsProd()) + if err != nil { + emailBackend = &mailer.DefaultBackend{Enabled: false} + } else { + log.Info("Email backend configured") + } + + return app.App{ + DB: db, + Clock: clock.New(), + EmailTemplates: mailer.NewTemplates(), + EmailBackend: emailBackend, + HTTP500Page: cfg.HTTP500Page, + AppEnv: cfg.AppEnv, + WebURL: cfg.WebURL, + DisableRegistration: cfg.DisableRegistration, + Port: cfg.Port, + DBPath: cfg.DBPath, + AssetBaseURL: cfg.AssetBaseURL, + } +} + +// setupFlagSet creates a FlagSet with standard usage format +func setupFlagSet(name, usageCmd string) *flag.FlagSet { + fs := flag.NewFlagSet(name, flag.ExitOnError) + fs.Usage = func() { + fmt.Printf(`Usage: + %s [flags] + +Flags: +`, usageCmd) + fs.PrintDefaults() + } + return fs +} + +// requireString validates that a required string flag is not empty +func requireString(fs *flag.FlagSet, value, fieldName string) { + if value == "" { + fmt.Printf("Error: %s is required\n", fieldName) + fs.Usage() + os.Exit(1) + } +} + +// setupAppWithDB creates config, initializes app, and returns cleanup function +func setupAppWithDB(fs *flag.FlagSet, dbPath string) (*app.App, func()) { + cfg, err := config.New(config.Params{ + DBPath: dbPath, + }) + if err != nil { + fmt.Printf("Error: %s\n\n", err) + fs.Usage() + os.Exit(1) + } + + a := initApp(cfg) + cleanup := func() { + sqlDB, err := a.DB.DB() + if err == nil { + sqlDB.Close() + } + } + + return &a, cleanup +} diff --git a/pkg/server/cmd/root.go b/pkg/server/cmd/root.go new file mode 100644 index 00000000..1ed2129f --- /dev/null +++ b/pkg/server/cmd/root.go @@ -0,0 +1,60 @@ +/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors + * + * This file is part of Dnote. + * + * Dnote is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Dnote is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Dnote. If not, see . + */ + +package cmd + +import ( + "fmt" + "os" +) + +func rootCmd() { + fmt.Printf(`Dnote server - a simple command line notebook + +Usage: + dnote-server [command] [flags] + +Available commands: + start: Start the server (use 'dnote-server start --help' for flags) + user: Manage users (use 'dnote-server user' for subcommands) + version: Print the version +`) +} + +// Execute is the main entry point for the CLI +func Execute() { + if len(os.Args) < 2 { + rootCmd() + return + } + + cmd := os.Args[1] + + switch cmd { + case "start": + startCmd(os.Args[2:]) + case "user": + userCmd(os.Args[2:]) + case "version": + versionCmd() + default: + fmt.Printf("Unknown command %s\n", cmd) + rootCmd() + os.Exit(1) + } +} diff --git a/pkg/server/cmd/start.go b/pkg/server/cmd/start.go new file mode 100644 index 00000000..cbbc60ed --- /dev/null +++ b/pkg/server/cmd/start.go @@ -0,0 +1,91 @@ +/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors + * + * This file is part of Dnote. + * + * Dnote is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Dnote is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Dnote. If not, see . + */ + +package cmd + +import ( + "fmt" + "net/http" + "os" + + "github.com/dnote/dnote/pkg/server/buildinfo" + "github.com/dnote/dnote/pkg/server/config" + "github.com/dnote/dnote/pkg/server/controllers" + "github.com/dnote/dnote/pkg/server/log" + "github.com/pkg/errors" +) + +func startCmd(args []string) { + fs := setupFlagSet("start", "dnote-server start") + + appEnv := fs.String("appEnv", "", "Application environment (env: APP_ENV, default: PRODUCTION)") + port := fs.String("port", "", "Server port (env: PORT, default: 3001)") + webURL := fs.String("webUrl", "", "Full URL to server without trailing slash (env: WebURL, default: http://localhost:3001)") + dbPath := fs.String("dbPath", "", "Path to SQLite database file (env: DBPath, default: $XDG_DATA_HOME/dnote/server.db)") + disableRegistration := fs.Bool("disableRegistration", false, "Disable user registration (env: DisableRegistration, default: false)") + logLevel := fs.String("logLevel", "", "Log level: debug, info, warn, or error (env: LOG_LEVEL, default: info)") + + fs.Parse(args) + + cfg, err := config.New(config.Params{ + AppEnv: *appEnv, + Port: *port, + WebURL: *webURL, + DBPath: *dbPath, + DisableRegistration: *disableRegistration, + LogLevel: *logLevel, + }) + if err != nil { + fmt.Printf("Error: %s\n\n", err) + fs.Usage() + os.Exit(1) + } + + // Set log level + log.SetLevel(cfg.LogLevel) + + app := initApp(cfg) + defer func() { + sqlDB, err := app.DB.DB() + if err == nil { + sqlDB.Close() + } + }() + + ctl := controllers.New(&app) + rc := controllers.RouteConfig{ + WebRoutes: controllers.NewWebRoutes(&app, ctl), + APIRoutes: controllers.NewAPIRoutes(&app, ctl), + Controllers: ctl, + } + + r, err := controllers.NewRouter(&app, rc) + if err != nil { + panic(errors.Wrap(err, "initializing router")) + } + + log.WithFields(log.Fields{ + "version": buildinfo.Version, + "port": cfg.Port, + }).Info("Dnote server starting") + + if err := http.ListenAndServe(fmt.Sprintf(":%s", cfg.Port), r); err != nil { + log.ErrorWrap(err, "server failed") + os.Exit(1) + } +} diff --git a/pkg/server/cmd/user.go b/pkg/server/cmd/user.go new file mode 100644 index 00000000..6123cdae --- /dev/null +++ b/pkg/server/cmd/user.go @@ -0,0 +1,190 @@ +/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors + * + * This file is part of Dnote. + * + * Dnote is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Dnote is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Dnote. If not, see . + */ + +package cmd + +import ( + "fmt" + "io" + "os" + + "github.com/dnote/dnote/pkg/prompt" + "github.com/dnote/dnote/pkg/server/app" + "github.com/dnote/dnote/pkg/server/log" + "github.com/pkg/errors" +) + +// confirm prompts for user input to confirm a choice +func confirm(r io.Reader, question string, optimistic bool) (bool, error) { + message := prompt.FormatQuestion(question, optimistic) + fmt.Print(message + " ") + + confirmed, err := prompt.ReadYesNo(r, optimistic) + if err != nil { + return false, errors.Wrap(err, "reading stdin") + } + + return confirmed, nil +} + +func userCreateCmd(args []string) { + fs := setupFlagSet("create", "dnote-server user create") + + email := fs.String("email", "", "User email address (required)") + password := fs.String("password", "", "User password (required)") + dbPath := fs.String("dbPath", "", "Path to SQLite database file (env: DBPath, default: $XDG_DATA_HOME/dnote/server.db)") + + fs.Parse(args) + + requireString(fs, *email, "email") + requireString(fs, *password, "password") + + a, cleanup := setupAppWithDB(fs, *dbPath) + defer cleanup() + + _, err := a.CreateUser(*email, *password, *password) + if err != nil { + log.ErrorWrap(err, "creating user") + os.Exit(1) + } + + fmt.Printf("User created successfully\n") + fmt.Printf("Email: %s\n", *email) +} + +func userRemoveCmd(args []string, stdin io.Reader) { + fs := setupFlagSet("remove", "dnote-server user remove") + + email := fs.String("email", "", "User email address (required)") + dbPath := fs.String("dbPath", "", "Path to SQLite database file (env: DBPath, default: $XDG_DATA_HOME/dnote/server.db)") + + fs.Parse(args) + + requireString(fs, *email, "email") + + a, cleanup := setupAppWithDB(fs, *dbPath) + defer cleanup() + + // Check if user exists first + _, err := a.GetAccountByEmail(*email) + if err != nil { + if errors.Is(err, app.ErrNotFound) { + fmt.Printf("Error: user with email %s not found\n", *email) + } else { + log.ErrorWrap(err, "finding account") + } + os.Exit(1) + } + + // Show confirmation prompt + ok, err := confirm(stdin, fmt.Sprintf("Remove user %s?", *email), false) + if err != nil { + log.ErrorWrap(err, "getting confirmation") + os.Exit(1) + } + if !ok { + fmt.Println("Aborted by user") + os.Exit(0) + } + + // Remove the user + if err := a.RemoveUser(*email); err != nil { + if errors.Is(err, app.ErrNotFound) { + fmt.Printf("Error: user with email %s not found\n", *email) + } else if errors.Is(err, app.ErrUserHasExistingResources) { + fmt.Printf("Error: %s\n", err) + } else { + log.ErrorWrap(err, "removing user") + } + os.Exit(1) + } + + fmt.Printf("User removed successfully\n") + fmt.Printf("Email: %s\n", *email) +} + +func userResetPasswordCmd(args []string) { + fs := setupFlagSet("reset-password", "dnote-server user reset-password") + + email := fs.String("email", "", "User email address (required)") + password := fs.String("password", "", "New password (required)") + dbPath := fs.String("dbPath", "", "Path to SQLite database file (env: DBPath, default: $XDG_DATA_HOME/dnote/server.db)") + + fs.Parse(args) + + requireString(fs, *email, "email") + requireString(fs, *password, "password") + + a, cleanup := setupAppWithDB(fs, *dbPath) + defer cleanup() + + // Find the account + account, err := a.GetAccountByEmail(*email) + if err != nil { + if errors.Is(err, app.ErrNotFound) { + fmt.Printf("Error: user with email %s not found\n", *email) + } else { + log.ErrorWrap(err, "finding account") + } + os.Exit(1) + } + + // Update the password + if err := app.UpdateAccountPassword(a.DB, account, *password); err != nil { + log.ErrorWrap(err, "updating password") + os.Exit(1) + } + + fmt.Printf("Password reset successfully\n") + fmt.Printf("Email: %s\n", *email) +} + +func userCmd(args []string) { + if len(args) < 1 { + fmt.Println(`Usage: + dnote-server user [command] + +Available commands: + create: Create a new user + remove: Remove a user + reset-password: Reset a user's password`) + os.Exit(1) + } + + subcommand := args[0] + subArgs := []string{} + if len(args) > 1 { + subArgs = args[1:] + } + + switch subcommand { + case "create": + userCreateCmd(subArgs) + case "remove": + userRemoveCmd(subArgs, os.Stdin) + case "reset-password": + userResetPasswordCmd(subArgs) + default: + fmt.Printf("Unknown subcommand: %s\n\n", subcommand) + fmt.Println(`Available commands: + create: Create a new user + remove: Remove a user (only if they have no notes or books) + reset-password: Reset a user's password`) + os.Exit(1) + } +} diff --git a/pkg/server/cmd/user_test.go b/pkg/server/cmd/user_test.go new file mode 100644 index 00000000..d3536d52 --- /dev/null +++ b/pkg/server/cmd/user_test.go @@ -0,0 +1,114 @@ +/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors + * + * This file is part of Dnote. + * + * Dnote is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Dnote is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Dnote. If not, see . + */ + +package cmd + +import ( + "strings" + "testing" + + "github.com/dnote/dnote/pkg/assert" + "github.com/dnote/dnote/pkg/server/database" + "github.com/dnote/dnote/pkg/server/testutils" + "golang.org/x/crypto/bcrypt" +) + +func TestUserCreateCmd(t *testing.T) { + tmpDB := t.TempDir() + "/test.db" + + // Call the function directly + userCreateCmd([]string{"--dbPath", tmpDB, "--email", "test@example.com", "--password", "password123"}) + + // Verify user was created in database + db := testutils.InitDB(tmpDB) + defer func() { + sqlDB, _ := db.DB() + sqlDB.Close() + }() + + var count int64 + testutils.MustExec(t, db.Model(&database.User{}).Count(&count), "counting users") + assert.Equal(t, count, int64(1), "should have 1 user") + + var account database.Account + testutils.MustExec(t, db.Where("email = ?", "test@example.com").First(&account), "finding account") + assert.Equal(t, account.Email.String, "test@example.com", "email mismatch") +} + +func TestUserRemoveCmd(t *testing.T) { + tmpDB := t.TempDir() + "/test.db" + + // Create a user first + db := testutils.InitDB(tmpDB) + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "test@example.com", "password123") + sqlDB, _ := db.DB() + sqlDB.Close() + + // Remove the user with mock stdin that responds "y" + mockStdin := strings.NewReader("y\n") + userRemoveCmd([]string{"--dbPath", tmpDB, "--email", "test@example.com"}, mockStdin) + + // Verify user was removed + db2 := testutils.InitDB(tmpDB) + defer func() { + sqlDB2, _ := db2.DB() + sqlDB2.Close() + }() + + var count int64 + testutils.MustExec(t, db2.Model(&database.User{}).Count(&count), "counting users") + assert.Equal(t, count, int64(0), "should have 0 users") +} + +func TestUserResetPasswordCmd(t *testing.T) { + tmpDB := t.TempDir() + "/test.db" + + // Create a user first + db := testutils.InitDB(tmpDB) + user := testutils.SetupUserData(db) + account := testutils.SetupAccountData(db, user, "test@example.com", "oldpassword123") + oldPasswordHash := account.Password.String + sqlDB, _ := db.DB() + sqlDB.Close() + + // Reset password + userResetPasswordCmd([]string{"--dbPath", tmpDB, "--email", "test@example.com", "--password", "newpassword123"}) + + // Verify password was changed + db2 := testutils.InitDB(tmpDB) + defer func() { + sqlDB2, _ := db2.DB() + sqlDB2.Close() + }() + + var updatedAccount database.Account + testutils.MustExec(t, db2.Where("email = ?", "test@example.com").First(&updatedAccount), "finding account") + + // Verify password hash changed + assert.Equal(t, updatedAccount.Password.String != oldPasswordHash, true, "password hash should be different") + assert.Equal(t, len(updatedAccount.Password.String) > 0, true, "password should be set") + + // Verify new password works + err := bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("newpassword123")) + assert.Equal(t, err, nil, "new password should match") + + // Verify old password doesn't work + err = bcrypt.CompareHashAndPassword([]byte(updatedAccount.Password.String), []byte("oldpassword123")) + assert.Equal(t, err != nil, true, "old password should not match") +} diff --git a/pkg/server/cmd/version.go b/pkg/server/cmd/version.go new file mode 100644 index 00000000..99c68429 --- /dev/null +++ b/pkg/server/cmd/version.go @@ -0,0 +1,29 @@ +/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors + * + * This file is part of Dnote. + * + * Dnote is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Dnote is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Dnote. If not, see . + */ + +package cmd + +import ( + "fmt" + + "github.com/dnote/dnote/pkg/server/buildinfo" +) + +func versionCmd() { + fmt.Printf("dnote-server-%s\n", buildinfo.Version) +} diff --git a/pkg/server/controllers/users.go b/pkg/server/controllers/users.go index 67baca6a..b1945d8f 100644 --- a/pkg/server/controllers/users.go +++ b/pkg/server/controllers/users.go @@ -396,27 +396,21 @@ func (u *Users) PasswordReset(w http.ResponseWriter, r *http.Request) { return } - tx := u.app.DB.Begin() - - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(params.Password), bcrypt.DefaultCost) - if err != nil { - tx.Rollback() - handleHTMLError(w, r, err, "hashing password", u.PasswordResetConfirmView, vd) - return - } - var account database.Account if err := u.app.DB.Where("user_id = ?", token.UserID).First(&account).Error; err != nil { - tx.Rollback() handleHTMLError(w, r, err, "finding user", u.PasswordResetConfirmView, vd) return } - if err := tx.Model(&account).Update("password", string(hashedPassword)).Error; err != nil { + tx := u.app.DB.Begin() + + // Update the password + if err := app.UpdateAccountPassword(tx, &account, params.Password); err != nil { tx.Rollback() handleHTMLError(w, r, err, "updating password", u.PasswordResetConfirmView, vd) return } + if err := tx.Model(&token).Update("used_at", time.Now()).Error; err != nil { tx.Rollback() handleHTMLError(w, r, err, "updating password reset token", u.PasswordResetConfirmView, vd) @@ -514,18 +508,7 @@ func (u *Users) PasswordUpdate(w http.ResponseWriter, r *http.Request) { return } - if err := validatePassword(form.NewPassword); err != nil { - handleHTMLError(w, r, err, "invalid password", u.SettingView, vd) - return - } - - hashedNewPassword, err := bcrypt.GenerateFromPassword([]byte(form.NewPassword), bcrypt.DefaultCost) - if err != nil { - handleHTMLError(w, r, err, "hashing password", u.SettingView, vd) - return - } - - if err := u.app.DB.Model(&account).Update("password", string(hashedNewPassword)).Error; err != nil { + if err := app.UpdateAccountPassword(u.app.DB, &account, form.NewPassword); err != nil { handleHTMLError(w, r, err, "updating password", u.SettingView, vd) return } @@ -537,14 +520,6 @@ func (u *Users) PasswordUpdate(w http.ResponseWriter, r *http.Request) { views.RedirectAlert(w, r, "/", http.StatusFound, alert) } -func validatePassword(password string) error { - if len(password) < 8 { - return app.ErrPasswordTooShort - } - - return nil -} - type updateProfileForm struct { Email string `schema:"email"` Password string `schema:"password"` diff --git a/pkg/server/log/log_test.go b/pkg/server/log/log_test.go new file mode 100644 index 00000000..3df63b3a --- /dev/null +++ b/pkg/server/log/log_test.go @@ -0,0 +1,38 @@ +/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors + * + * This file is part of Dnote. + * + * Dnote is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Dnote is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Dnote. If not, see . + */ + +package log + +import ( + "testing" +) + +func TestSetLevel(t *testing.T) { + // Reset to default after test + defer SetLevel(LevelInfo) + + SetLevel(LevelDebug) + if currentLevel != LevelDebug { + t.Errorf("Expected level %s, got %s", LevelDebug, currentLevel) + } + + SetLevel(LevelError) + if currentLevel != LevelError { + t.Errorf("Expected level %s, got %s", LevelError, currentLevel) + } +} diff --git a/pkg/server/main.go b/pkg/server/main.go index e8fa1aa8..701912ea 100644 --- a/pkg/server/main.go +++ b/pkg/server/main.go @@ -19,156 +19,9 @@ package main import ( - "flag" - "fmt" - "net/http" - "os" - - "github.com/dnote/dnote/pkg/clock" - "github.com/dnote/dnote/pkg/server/app" - "github.com/dnote/dnote/pkg/server/buildinfo" - "github.com/dnote/dnote/pkg/server/config" - "github.com/dnote/dnote/pkg/server/controllers" - "github.com/dnote/dnote/pkg/server/database" - "github.com/dnote/dnote/pkg/server/log" - "github.com/dnote/dnote/pkg/server/mailer" - "github.com/pkg/errors" - "gorm.io/gorm" + "github.com/dnote/dnote/pkg/server/cmd" ) -func initDB(dbPath string) *gorm.DB { - db := database.Open(dbPath) - database.InitSchema(db) - database.Migrate(db) - - return db -} - -func initApp(cfg config.Config) app.App { - db := initDB(cfg.DBPath) - - emailBackend, err := mailer.NewDefaultBackend(cfg.IsProd()) - if err != nil { - emailBackend = &mailer.DefaultBackend{Enabled: false} - } else { - log.Info("Email backend configured") - } - - return app.App{ - DB: db, - Clock: clock.New(), - EmailTemplates: mailer.NewTemplates(), - EmailBackend: emailBackend, - HTTP500Page: cfg.HTTP500Page, - AppEnv: cfg.AppEnv, - WebURL: cfg.WebURL, - DisableRegistration: cfg.DisableRegistration, - Port: cfg.Port, - DBPath: cfg.DBPath, - AssetBaseURL: cfg.AssetBaseURL, - } -} - -func startCmd(args []string) { - startFlags := flag.NewFlagSet("start", flag.ExitOnError) - startFlags.Usage = func() { - fmt.Printf(`Usage: - dnote-server start [flags] - -Flags: -`) - startFlags.PrintDefaults() - } - - appEnv := startFlags.String("appEnv", "", "Application environment (env: APP_ENV, default: PRODUCTION)") - port := startFlags.String("port", "", "Server port (env: PORT, default: 3001)") - webURL := startFlags.String("webUrl", "", "Full URL to server without trailing slash (env: WebURL, default: http://localhost:3001)") - dbPath := startFlags.String("dbPath", "", "Path to SQLite database file (env: DBPath, default: $XDG_DATA_HOME/dnote/server.db)") - disableRegistration := startFlags.Bool("disableRegistration", false, "Disable user registration (env: DisableRegistration, default: false)") - logLevel := startFlags.String("logLevel", "", "Log level: debug, info, warn, or error (env: LOG_LEVEL, default: info)") - - startFlags.Parse(args) - - cfg, err := config.New(config.Params{ - AppEnv: *appEnv, - Port: *port, - WebURL: *webURL, - DBPath: *dbPath, - DisableRegistration: *disableRegistration, - LogLevel: *logLevel, - }) - if err != nil { - fmt.Printf("Error: %s\n\n", err) - startFlags.Usage() - os.Exit(1) - } - - // Set log level - log.SetLevel(cfg.LogLevel) - - app := initApp(cfg) - defer func() { - sqlDB, err := app.DB.DB() - if err == nil { - sqlDB.Close() - } - }() - - ctl := controllers.New(&app) - rc := controllers.RouteConfig{ - WebRoutes: controllers.NewWebRoutes(&app, ctl), - APIRoutes: controllers.NewAPIRoutes(&app, ctl), - Controllers: ctl, - } - - r, err := controllers.NewRouter(&app, rc) - if err != nil { - panic(errors.Wrap(err, "initializing router")) - } - - log.WithFields(log.Fields{ - "version": buildinfo.Version, - "port": cfg.Port, - }).Info("Dnote server starting") - - if err := http.ListenAndServe(fmt.Sprintf(":%s", cfg.Port), r); err != nil { - log.ErrorWrap(err, "server failed") - os.Exit(1) - } -} - -func versionCmd() { - fmt.Printf("dnote-server-%s\n", buildinfo.Version) -} - -func rootCmd() { - fmt.Printf(`Dnote server - a simple command line notebook - -Usage: - dnote-server [command] [flags] - -Available commands: - start: Start the server (use 'dnote-server start --help' for flags) - version: Print the version -`) -} - func main() { - if len(os.Args) < 2 { - rootCmd() - return - } - - cmd := os.Args[1] - - switch cmd { - case "start": - startCmd(os.Args[2:]) - case "version": - versionCmd() - default: - fmt.Printf("Unknown command %s\n", cmd) - rootCmd() - os.Exit(1) - } + cmd.Execute() }