From 1ca640812e2b1060bca0c9d9c9dccc549e8d28ca Mon Sep 17 00:00:00 2001 From: Sung Date: Sun, 12 Oct 2025 13:20:27 -0700 Subject: [PATCH] Allow to specify CLI db path as a flag --- pkg/cli/cmd/root/root.go | 7 +++++ pkg/cli/infra/init.go | 15 ++++++--- pkg/cli/main.go | 6 ++-- pkg/cli/main_test.go | 68 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 89 insertions(+), 7 deletions(-) diff --git a/pkg/cli/cmd/root/root.go b/pkg/cli/cmd/root/root.go index 750629e4..b47a3e3c 100644 --- a/pkg/cli/cmd/root/root.go +++ b/pkg/cli/cmd/root/root.go @@ -23,6 +23,7 @@ import ( ) var apiEndpointFlag string +var dbPathFlag string var root = &cobra.Command{ Use: "dnote", @@ -36,6 +37,7 @@ var root = &cobra.Command{ func init() { root.PersistentFlags().StringVar(&apiEndpointFlag, "api-endpoint", "", "the API endpoint to connect to (defaults to value in config)") + root.PersistentFlags().StringVar(&dbPathFlag, "dbPath", "", "the path to the database file (defaults to standard location)") } // GetRoot returns the root command @@ -48,6 +50,11 @@ func GetAPIEndpointFlag() string { return apiEndpointFlag } +// GetDBPathFlag returns the value of the --dbpath flag +func GetDBPathFlag() string { + return dbPathFlag +} + // Register adds a new command func Register(cmd *cobra.Command) { root.AddCommand(cmd) diff --git a/pkg/cli/infra/init.go b/pkg/cli/infra/init.go index 532f1f1e..16477dd9 100644 --- a/pkg/cli/infra/init.go +++ b/pkg/cli/infra/init.go @@ -59,7 +59,12 @@ func checkLegacyDBPath() (string, bool) { return "", false } -func getDBPath(paths context.Paths) string { +func getDBPath(paths context.Paths, customPath string) string { + // If custom path is provided, use it + if customPath != "" { + return customPath + } + legacyDnoteDir, ok := checkLegacyDBPath() if ok { return fmt.Sprintf("%s/%s", legacyDnoteDir, consts.DnoteDBFileName) @@ -71,7 +76,7 @@ func getDBPath(paths context.Paths) string { // newBaseCtx creates a minimal context with paths and database connection. // This base context is used for file and database initialization before // being enriched with config values by setupCtx. -func newBaseCtx(versionTag string) (context.DnoteCtx, error) { +func newBaseCtx(versionTag, customDBPath string) (context.DnoteCtx, error) { dnoteDir := getLegacyDnotePath(dirs.Home) paths := context.Paths{ Home: dirs.Home, @@ -81,7 +86,7 @@ func newBaseCtx(versionTag string) (context.DnoteCtx, error) { LegacyDnote: dnoteDir, } - dbPath := getDBPath(paths) + dbPath := getDBPath(paths, customDBPath) db, err := database.Open(dbPath) if err != nil { @@ -98,8 +103,8 @@ func newBaseCtx(versionTag string) (context.DnoteCtx, error) { } // Init initializes the Dnote environment and returns a new dnote context -func Init(versionTag, apiEndpoint string) (*context.DnoteCtx, error) { - ctx, err := newBaseCtx(versionTag) +func Init(versionTag, apiEndpoint, customDBPath string) (*context.DnoteCtx, error) { + ctx, err := newBaseCtx(versionTag, customDBPath) if err != nil { return nil, errors.Wrap(err, "initializing a context") } diff --git a/pkg/cli/main.go b/pkg/cli/main.go index 31c81dd5..db727fca 100644 --- a/pkg/cli/main.go +++ b/pkg/cli/main.go @@ -46,7 +46,7 @@ var apiEndpoint string var versionTag = "master" func main() { - // Parse flags early to check if --api-endpoint was provided + // Parse flags early to check if --api-endpoint and --dbpath were provided root.GetRoot().ParseFlags(os.Args[1:]) // Use flag value if provided, otherwise use ldflags value @@ -55,7 +55,9 @@ func main() { endpoint = flagValue } - ctx, err := infra.Init(versionTag, endpoint) + dbPath := root.GetDBPathFlag() + + ctx, err := infra.Init(versionTag, endpoint, dbPath) if err != nil { panic(errors.Wrap(err, "initializing context")) } diff --git a/pkg/cli/main_test.go b/pkg/cli/main_test.go index 151d4cbf..63a95d87 100644 --- a/pkg/cli/main_test.go +++ b/pkg/cli/main_test.go @@ -501,3 +501,71 @@ func TestRemoveBook(t *testing.T) { }) } } + +func TestDBPathFlag(t *testing.T) { + // Helper function to verify database contents + verifyDatabase := func(t *testing.T, dbPath, expectedBook, expectedNote string) *database.DB { + ok, err := utils.FileExists(dbPath) + if err != nil { + t.Fatal(errors.Wrapf(err, "checking if custom db exists at %s", dbPath)) + } + if !ok { + t.Errorf("custom database was not created at %s", dbPath) + } + + db, err := database.Open(dbPath) + if err != nil { + t.Fatal(errors.Wrapf(err, "opening db at %s", dbPath)) + } + + var noteCount, bookCount int + database.MustScan(t, "counting books", db.QueryRow("SELECT count(*) FROM books"), &bookCount) + database.MustScan(t, "counting notes", db.QueryRow("SELECT count(*) FROM notes"), ¬eCount) + + assert.Equalf(t, bookCount, 1, fmt.Sprintf("%s book count mismatch", dbPath)) + assert.Equalf(t, noteCount, 1, fmt.Sprintf("%s note count mismatch", dbPath)) + + var book database.Book + database.MustScan(t, "getting book", db.QueryRow("SELECT label FROM books"), &book.Label) + assert.Equalf(t, book.Label, expectedBook, fmt.Sprintf("%s book label mismatch", dbPath)) + + var note database.Note + database.MustScan(t, "getting note", db.QueryRow("SELECT body FROM notes"), ¬e.Body) + assert.Equalf(t, note.Body, expectedNote, fmt.Sprintf("%s note body mismatch", dbPath)) + + return db + } + + // Setup - use two different custom database paths + customDBPath1 := "./tmp/custom-test1.db" + customDBPath2 := "./tmp/custom-test2.db" + defer testutils.RemoveDir(t, "./tmp") + + customOpts := testutils.RunDnoteCmdOptions{ + Env: []string{ + fmt.Sprintf("XDG_CONFIG_HOME=%s", testDir), + fmt.Sprintf("XDG_DATA_HOME=%s", testDir), + fmt.Sprintf("XDG_CACHE_HOME=%s", testDir), + }, + } + + // Execute - add different notes to each database + testutils.RunDnoteCmd(t, customOpts, binaryName, "--dbPath", customDBPath1, "add", "db1-book", "-c", "content in db1") + testutils.RunDnoteCmd(t, customOpts, binaryName, "--dbPath", customDBPath2, "add", "db2-book", "-c", "content in db2") + + // Test both databases + db1 := verifyDatabase(t, customDBPath1, "db1-book", "content in db1") + defer db1.Close() + + db2 := verifyDatabase(t, customDBPath2, "db2-book", "content in db2") + defer db2.Close() + + // Verify that the databases are independent + var db1HasDB2Book int + db1.QueryRow("SELECT count(*) FROM books WHERE label = ?", "db2-book").Scan(&db1HasDB2Book) + assert.Equal(t, db1HasDB2Book, 0, "db1 should not have db2's book") + + var db2HasDB1Book int + db2.QueryRow("SELECT count(*) FROM books WHERE label = ?", "db1-book").Scan(&db2HasDB1Book) + assert.Equal(t, db2HasDB1Book, 0, "db2 should not have db1's book") +}