diff --git a/pkg/e2e/server_test.go b/pkg/e2e/server_test.go index cfe6a711..39da46de 100644 --- a/pkg/e2e/server_test.go +++ b/pkg/e2e/server_test.go @@ -50,7 +50,7 @@ func TestServerStart(t *testing.T) { port := "13456" // Use different port to avoid conflicts with main test server // Start server in background - cmd := exec.Command(testServerBinary, "start", "-port", port) + cmd := exec.Command(testServerBinary, "start", "--port", port) cmd.Env = append(os.Environ(), "DBPath="+tmpDB, "WebURL=http://localhost:"+port, @@ -143,11 +143,11 @@ func TestServerStartHelp(t *testing.T) { outputStr := string(output) assert.Equal(t, strings.Contains(outputStr, "dnote-server start [flags]"), true, "output should contain usage") - assert.Equal(t, strings.Contains(outputStr, "-appEnv"), true, "output should contain appEnv flag") - assert.Equal(t, strings.Contains(outputStr, "-port"), true, "output should contain port flag") - assert.Equal(t, strings.Contains(outputStr, "-webUrl"), true, "output should contain webUrl flag") - assert.Equal(t, strings.Contains(outputStr, "-dbPath"), true, "output should contain dbPath flag") - assert.Equal(t, strings.Contains(outputStr, "-disableRegistration"), true, "output should contain disableRegistration flag") + assert.Equal(t, strings.Contains(outputStr, "--appEnv"), true, "output should contain appEnv flag") + assert.Equal(t, strings.Contains(outputStr, "--port"), true, "output should contain port flag") + assert.Equal(t, strings.Contains(outputStr, "--webUrl"), true, "output should contain webUrl flag") + assert.Equal(t, strings.Contains(outputStr, "--dbPath"), true, "output should contain dbPath flag") + assert.Equal(t, strings.Contains(outputStr, "--disableRegistration"), true, "output should contain disableRegistration flag") } func TestServerStartInvalidConfig(t *testing.T) { @@ -166,7 +166,7 @@ func TestServerStartInvalidConfig(t *testing.T) { assert.Equal(t, strings.Contains(outputStr, "Error:"), true, "output should contain error message") assert.Equal(t, strings.Contains(outputStr, "Invalid WebURL"), true, "output should mention invalid WebURL") assert.Equal(t, strings.Contains(outputStr, "dnote-server start [flags]"), true, "output should show usage") - assert.Equal(t, strings.Contains(outputStr, "-webUrl"), true, "output should show flags") + assert.Equal(t, strings.Contains(outputStr, "--webUrl"), true, "output should show flags") } func TestServerUnknownCommand(t *testing.T) { @@ -321,3 +321,19 @@ func TestServerUserRemove(t *testing.T) { db.Table("users").Count(&count) assert.Equal(t, count, int64(0), "should have 0 users after removal") } + +func TestServerUserCreateHelp(t *testing.T) { + cmd := exec.Command(testServerBinary, "user", "create", "--help") + output, err := cmd.CombinedOutput() + + if err != nil { + t.Fatalf("help command failed: %v\nOutput: %s", err, output) + } + + outputStr := string(output) + + // Verify help shows double-dash flags for consistency with CLI + assert.Equal(t, strings.Contains(outputStr, "--email"), true, "help should show --email (double dash)") + assert.Equal(t, strings.Contains(outputStr, "--password"), true, "help should show --password (double dash)") + assert.Equal(t, strings.Contains(outputStr, "--dbPath"), true, "help should show --dbPath (double dash)") +} diff --git a/pkg/server/cmd/helpers.go b/pkg/server/cmd/helpers.go index a22c8721..eb58c318 100644 --- a/pkg/server/cmd/helpers.go +++ b/pkg/server/cmd/helpers.go @@ -65,6 +65,29 @@ func initApp(cfg config.Config) app.App { } } +// printFlags prints flags with -- prefix for consistency with CLI +func printFlags(fs *flag.FlagSet) { + fs.VisitAll(func(f *flag.Flag) { + fmt.Printf(" --%s", f.Name) + + // Print type hint for non-boolean flags + name, usage := flag.UnquoteUsage(f) + if name != "" { + fmt.Printf(" %s", name) + } + fmt.Println() + + // Print usage description with indentation + if usage != "" { + fmt.Printf(" \t%s", usage) + if f.DefValue != "" && f.DefValue != "false" { + fmt.Printf(" (default: %s)", f.DefValue) + } + fmt.Println() + } + }) +} + // setupFlagSet creates a FlagSet with standard usage format func setupFlagSet(name, usageCmd string) *flag.FlagSet { fs := flag.NewFlagSet(name, flag.ExitOnError) @@ -74,7 +97,7 @@ func setupFlagSet(name, usageCmd string) *flag.FlagSet { Flags: `, usageCmd) - fs.PrintDefaults() + printFlags(fs) } return fs } diff --git a/pkg/server/middleware/auth.go b/pkg/server/middleware/auth.go index 984079e4..f74d1efa 100644 --- a/pkg/server/middleware/auth.go +++ b/pkg/server/middleware/auth.go @@ -101,7 +101,11 @@ func WithAccount(db *gorm.DB, next http.HandlerFunc) http.HandlerFunc { user := context.User(r.Context()) var account database.Account - if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil { + err := db.Where("user_id = ?", user.ID).First(&account).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + DoError(w, "account not found", err, http.StatusForbidden) + return + } else if err != nil { DoError(w, "finding account", err, http.StatusInternalServerError) return } diff --git a/pkg/server/middleware/auth_test.go b/pkg/server/middleware/auth_test.go index 1485befa..c0d94096 100644 --- a/pkg/server/middleware/auth_test.go +++ b/pkg/server/middleware/auth_test.go @@ -233,3 +233,37 @@ func TestTokenAuth(t *testing.T) { assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch") }) } + +func TestWithAccount(t *testing.T) { + db := testutils.InitMemoryDB(t) + + handler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + } + + t.Run("user with account", func(t *testing.T) { + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + + server := httptest.NewServer(Auth(db, handler, nil)) + defer server.Close() + + req := testutils.MakeReq(server.URL, "GET", "/", "") + res := testutils.HTTPAuthDo(t, db, req, user) + + assert.Equal(t, res.StatusCode, http.StatusOK, "status code mismatch") + }) + + t.Run("user without account", func(t *testing.T) { + user := testutils.SetupUserData(db) + // Note: not creating account for this user + + server := httptest.NewServer(Auth(db, handler, nil)) + defer server.Close() + + req := testutils.MakeReq(server.URL, "GET", "/", "") + res := testutils.HTTPAuthDo(t, db, req, user) + + assert.Equal(t, res.StatusCode, http.StatusForbidden, "status code mismatch") + }) +}