From 73526a943ceabf6af35c11d8fb3051b28eb38fe0 Mon Sep 17 00:00:00 2001 From: Sung Won Cho Date: Sun, 31 Mar 2019 16:23:46 +1100 Subject: [PATCH] Encryption (#165) * Implement login and logout * Add encrypt util * Use v2 * Abstract common interface between db and tx * Fix test * Check login * Fix test * Fix login * Fix path * Improve test * Fix output --- COMMANDS.md | 13 ++- Gopkg.lock | 20 +++- Gopkg.toml | 2 +- client/client.go | 206 ++++++++++++++++++++++++++++++------- cmd/login/login.go | 92 ++++++++++++----- cmd/logout/logout.go | 82 +++++++++++++++ cmd/sync/sync.go | 136 ++++++++++++++----------- cmd/sync/sync_test.go | 158 ++++++++++++++++++++++------ core/core.go | 36 ++++++- core/models.go | 35 ++++--- core/operations.go | 85 ++++++++++++++++ core/operations_test.go | 221 ++++++++++++++++++++++++++++++++++++++++ crypt/utils.go | 107 +++++++++++++++++++ crypt/utils_test.go | 100 ++++++++++++++++++ infra/main.go | 62 +++++++++-- infra/sql.go | 120 ++++++++++++++++++++++ log/log.go | 21 +++- main.go | 7 ++ migrate/migrate_test.go | 31 +++--- migrate/migrations.go | 34 +++---- scripts/test.sh | 9 +- testutils/main.go | 29 ++++-- utils/utils.go | 83 +++++++++++++-- 23 files changed, 1445 insertions(+), 244 deletions(-) create mode 100644 cmd/logout/logout.go create mode 100644 core/operations.go create mode 100644 core/operations_test.go create mode 100644 crypt/utils.go create mode 100644 crypt/utils_test.go create mode 100644 infra/sql.go diff --git a/COMMANDS.md b/COMMANDS.md index 10408d7e..aa4994c3 100644 --- a/COMMANDS.md +++ b/COMMANDS.md @@ -5,8 +5,9 @@ - [edit](#dnote-edit) - [remove](#dnote-remove) - [find](#dnote-find) -- [login](#dnote-login) - [sync](#dnote-sync) +- [login](#dnote-login) +- [logout](#dnote-logout) ## dnote add @@ -87,7 +88,7 @@ dnote find "merge sort" -b algorithm ## dnote sync -_Dnote Cloud only_ +_Dnote Pro only_ _alias: s_ @@ -95,6 +96,12 @@ Sync notes with Dnote cloud ## dnote login -_Dnote Cloud only_ +_Dnote Pro only_ Start a login prompt + +## dnote logout + +_Dnote Pro only_ + +Log out of Dnote diff --git a/Gopkg.lock b/Gopkg.lock index efef5ea8..4649541f 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -97,11 +97,26 @@ revision = "298182f68c66c05229eb03ac171abe6e309ee79a" version = "v1.0.3" +[[projects]] + branch = "master" + digest = "1:d0f4eb7abce3fbd3f0dcbbc03ffe18464846afd34c815928d2ae11c1e5aded04" + name = "golang.org/x/crypto" + packages = [ + "hkdf", + "pbkdf2", + "ssh/terminal", + ] + pruneopts = "" + revision = "ffb98f73852f696ea2bb21a617a5c4b3e067a439" + [[projects]] branch = "master" digest = "1:7e3b61f51ebcb58b3894928ed7c63aae68820dec1dd57166e5d6e65ef2868f40" name = "golang.org/x/sys" - packages = ["unix"] + packages = [ + "unix", + "windows", + ] pruneopts = "" revision = "b90733256f2e882e81d52f9126de08df5615afd9" @@ -124,6 +139,9 @@ "github.com/pkg/errors", "github.com/satori/go.uuid", "github.com/spf13/cobra", + "golang.org/x/crypto/hkdf", + "golang.org/x/crypto/pbkdf2", + "golang.org/x/crypto/ssh/terminal", "gopkg.in/yaml.v2", ] solver-name = "gps-cdcl" diff --git a/Gopkg.toml b/Gopkg.toml index 206df2ef..e5e6ef80 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -30,7 +30,7 @@ [[constraint]] name = "github.com/spf13/cobra" - version = "0.0.1" + version = "0.0.3" [[constraint]] branch = "v2" diff --git a/client/client.go b/client/client.go index cc8311d9..7d99f9d2 100644 --- a/client/client.go +++ b/client/client.go @@ -12,11 +12,15 @@ import ( "strings" "time" + "github.com/dnote/cli/crypt" "github.com/dnote/cli/infra" "github.com/dnote/cli/utils" "github.com/pkg/errors" ) +// ErrInvalidLogin is an error for invalid credentials for login +var ErrInvalidLogin = errors.New("wrong credentials") + // GetSyncStateResp is the response get sync state endpoint type GetSyncStateResp struct { FullSyncBefore int `json:"full_sync_before"` @@ -25,10 +29,11 @@ type GetSyncStateResp struct { } // GetSyncState gets the sync state response from the server -func GetSyncState(apiKey string, ctx infra.DnoteCtx) (GetSyncStateResp, error) { +func GetSyncState(ctx infra.DnoteCtx) (GetSyncStateResp, error) { var ret GetSyncStateResp - res, err := utils.DoAuthorizedReq(ctx, apiKey, "GET", "/v1/sync/state", "") + hc := http.Client{} + res, err := utils.DoAuthorizedReq(ctx, hc, "GET", "/v1/sync/state", "") if err != nil { return ret, errors.Wrap(err, "constructing http request") } @@ -89,13 +94,14 @@ type GetSyncFragmentResp struct { } // GetSyncFragment gets a sync fragment response from the server -func GetSyncFragment(ctx infra.DnoteCtx, apiKey string, afterUSN int) (GetSyncFragmentResp, error) { +func GetSyncFragment(ctx infra.DnoteCtx, afterUSN int) (GetSyncFragmentResp, error) { v := url.Values{} v.Set("after_usn", strconv.Itoa(afterUSN)) queryStr := v.Encode() path := fmt.Sprintf("/v1/sync/fragment?%s", queryStr) - res, err := utils.DoAuthorizedReq(ctx, apiKey, "GET", path, "") + hc := http.Client{} + res, err := utils.DoAuthorizedReq(ctx, hc, "GET", path, "") body, err := ioutil.ReadAll(res.Body) if err != nil { @@ -148,25 +154,31 @@ func checkRespErr(res *http.Response) (bool, string, error) { } // CreateBook creates a new book in the server -func CreateBook(ctx infra.DnoteCtx, apiKey, label string) (CreateBookResp, error) { +func CreateBook(ctx infra.DnoteCtx, label string) (CreateBookResp, error) { + encLabel, err := crypt.AesGcmEncrypt(ctx.CipherKey, []byte(label)) + if err != nil { + return CreateBookResp{}, errors.Wrap(err, "encrypting the label") + } + payload := CreateBookPayload{ - Name: label, + Name: encLabel, } b, err := json.Marshal(payload) if err != nil { return CreateBookResp{}, errors.Wrap(err, "marshaling payload") } - res, err := utils.DoAuthorizedReq(ctx, apiKey, "POST", "/v1/books", string(b)) + hc := http.Client{} + res, err := utils.DoAuthorizedReq(ctx, hc, "POST", "/v2/books", string(b)) if err != nil { return CreateBookResp{}, errors.Wrap(err, "posting a book to the server") } - ok, message, err := checkRespErr(res) + hasErr, message, err := checkRespErr(res) if err != nil { return CreateBookResp{}, errors.Wrap(err, "checking repsonse error") } - if ok { + if hasErr { return CreateBookResp{}, errors.New(message) } @@ -188,26 +200,32 @@ type UpdateBookResp struct { } // UpdateBook updates a book in the server -func UpdateBook(ctx infra.DnoteCtx, apiKey, label, uuid string) (UpdateBookResp, error) { +func UpdateBook(ctx infra.DnoteCtx, label, uuid string) (UpdateBookResp, error) { + encName, err := crypt.AesGcmEncrypt(ctx.CipherKey, []byte(label)) + if err != nil { + return UpdateBookResp{}, errors.Wrap(err, "encrypting the content") + } + payload := updateBookPayload{ - Name: &label, + Name: &encName, } b, err := json.Marshal(payload) if err != nil { return UpdateBookResp{}, errors.Wrap(err, "marshaling payload") } + hc := http.Client{} endpoint := fmt.Sprintf("/v1/books/%s", uuid) - res, err := utils.DoAuthorizedReq(ctx, apiKey, "PATCH", endpoint, string(b)) + res, err := utils.DoAuthorizedReq(ctx, hc, "PATCH", endpoint, string(b)) if err != nil { return UpdateBookResp{}, errors.Wrap(err, "posting a book to the server") } - ok, message, err := checkRespErr(res) + hasErr, message, err := checkRespErr(res) if err != nil { return UpdateBookResp{}, errors.Wrap(err, "checking repsonse error") } - if ok { + if hasErr { return UpdateBookResp{}, errors.New(message) } @@ -226,18 +244,19 @@ type DeleteBookResp struct { } // DeleteBook deletes a book in the server -func DeleteBook(ctx infra.DnoteCtx, apiKey, uuid string) (DeleteBookResp, error) { +func DeleteBook(ctx infra.DnoteCtx, uuid string) (DeleteBookResp, error) { + hc := http.Client{} endpoint := fmt.Sprintf("/v1/books/%s", uuid) - res, err := utils.DoAuthorizedReq(ctx, apiKey, "DELETE", endpoint, "") + res, err := utils.DoAuthorizedReq(ctx, hc, "DELETE", endpoint, "") if err != nil { return DeleteBookResp{}, errors.Wrap(err, "deleting a book in the server") } - ok, message, err := checkRespErr(res) + hasErr, message, err := checkRespErr(res) if err != nil { return DeleteBookResp{}, errors.Wrap(err, "checking repsonse error") } - if ok { + if hasErr { return DeleteBookResp{}, errors.New(message) } @@ -283,26 +302,32 @@ type RespNote struct { } // CreateNote creates a note in the server -func CreateNote(ctx infra.DnoteCtx, apiKey, bookUUID, content string) (CreateNoteResp, error) { +func CreateNote(ctx infra.DnoteCtx, bookUUID, content string) (CreateNoteResp, error) { + encBody, err := crypt.AesGcmEncrypt(ctx.CipherKey, []byte(content)) + if err != nil { + return CreateNoteResp{}, errors.Wrap(err, "encrypting the content") + } + payload := CreateNotePayload{ BookUUID: bookUUID, - Body: content, + Body: encBody, } b, err := json.Marshal(payload) if err != nil { return CreateNoteResp{}, errors.Wrap(err, "marshaling payload") } - res, err := utils.DoAuthorizedReq(ctx, apiKey, "POST", "/v1/notes", string(b)) + hc := http.Client{} + res, err := utils.DoAuthorizedReq(ctx, hc, "POST", "/v2/notes", string(b)) if err != nil { return CreateNoteResp{}, errors.Wrap(err, "posting a book to the server") } - ok, message, err := checkRespErr(res) + hasErr, message, err := checkRespErr(res) if err != nil { return CreateNoteResp{}, errors.Wrap(err, "checking repsonse error") } - if ok { + if hasErr { return CreateNoteResp{}, errors.New(message) } @@ -327,10 +352,15 @@ type UpdateNoteResp struct { } // UpdateNote updates a note in the server -func UpdateNote(ctx infra.DnoteCtx, apiKey, uuid, bookUUID, content string, public bool) (UpdateNoteResp, error) { +func UpdateNote(ctx infra.DnoteCtx, uuid, bookUUID, content string, public bool) (UpdateNoteResp, error) { + encBody, err := crypt.AesGcmEncrypt(ctx.CipherKey, []byte(content)) + if err != nil { + return UpdateNoteResp{}, errors.Wrap(err, "encrypting the content") + } + payload := updateNotePayload{ BookUUID: &bookUUID, - Body: &content, + Body: &encBody, Public: &public, } b, err := json.Marshal(payload) @@ -338,17 +368,18 @@ func UpdateNote(ctx infra.DnoteCtx, apiKey, uuid, bookUUID, content string, publ return UpdateNoteResp{}, errors.Wrap(err, "marshaling payload") } + hc := http.Client{} endpoint := fmt.Sprintf("/v1/notes/%s", uuid) - res, err := utils.DoAuthorizedReq(ctx, apiKey, "PATCH", endpoint, string(b)) + res, err := utils.DoAuthorizedReq(ctx, hc, "PATCH", endpoint, string(b)) if err != nil { return UpdateNoteResp{}, errors.Wrap(err, "patching a note to the server") } - ok, message, err := checkRespErr(res) + hasErr, message, err := checkRespErr(res) if err != nil { return UpdateNoteResp{}, errors.Wrap(err, "checking repsonse error") } - if ok { + if hasErr { return UpdateNoteResp{}, errors.New(message) } @@ -367,18 +398,19 @@ type DeleteNoteResp struct { } // DeleteNote removes a note in the server -func DeleteNote(ctx infra.DnoteCtx, apiKey, uuid string) (DeleteNoteResp, error) { +func DeleteNote(ctx infra.DnoteCtx, uuid string) (DeleteNoteResp, error) { + hc := http.Client{} endpoint := fmt.Sprintf("/v1/notes/%s", uuid) - res, err := utils.DoAuthorizedReq(ctx, apiKey, "DELETE", endpoint, "") + res, err := utils.DoAuthorizedReq(ctx, hc, "DELETE", endpoint, "") if err != nil { return DeleteNoteResp{}, errors.Wrap(err, "patching a note to the server") } - ok, message, err := checkRespErr(res) + hasErr, message, err := checkRespErr(res) if err != nil { return DeleteNoteResp{}, errors.Wrap(err, "checking repsonse error") } - if ok { + if hasErr { return DeleteNoteResp{}, errors.New(message) } @@ -397,17 +429,18 @@ type GetBooksResp []struct { } // GetBooks gets books from the server -func GetBooks(ctx infra.DnoteCtx, apiKey string) (GetBooksResp, error) { - res, err := utils.DoAuthorizedReq(ctx, apiKey, "GET", "/v1/books", "") +func GetBooks(ctx infra.DnoteCtx, sessionKey string) (GetBooksResp, error) { + hc := http.Client{} + res, err := utils.DoAuthorizedReq(ctx, hc, "GET", "/v1/books", "") if err != nil { return GetBooksResp{}, errors.Wrap(err, "making http request") } - ok, message, err := checkRespErr(res) + hasErr, message, err := checkRespErr(res) if err != nil { return GetBooksResp{}, errors.Wrap(err, "checking repsonse error") } - if ok { + if hasErr { return GetBooksResp{}, errors.New(message) } @@ -418,3 +451,104 @@ func GetBooks(ctx infra.DnoteCtx, apiKey string) (GetBooksResp, error) { return resp, nil } + +// PresigninResponse is a reponse from /v1/presignin endpoint +type PresigninResponse struct { + Iteration int `json:"iteration"` +} + +// GetPresignin gets presignin credentials +func GetPresignin(ctx infra.DnoteCtx, email string) (PresigninResponse, error) { + res, err := utils.DoReq(ctx, "GET", fmt.Sprintf("/v1/presignin?email=%s", email), "") + if err != nil { + return PresigninResponse{}, errors.Wrap(err, "making http request") + } + + hasErr, message, err := checkRespErr(res) + if err != nil { + return PresigninResponse{}, errors.Wrap(err, "checking repsonse error") + } + if hasErr { + return PresigninResponse{}, errors.New(message) + } + + var resp PresigninResponse + if err := json.NewDecoder(res.Body).Decode(&resp); err != nil { + return PresigninResponse{}, errors.Wrap(err, "decoding payload") + } + + return resp, nil +} + +// SigninPayload is a payload for /v1/signin +type SigninPayload struct { + Email string `json:"email"` + AuthKey string `json:"auth_key"` +} + +// SigninResponse is a response from /v1/signin endpoint +type SigninResponse struct { + Key string `json:"key"` + ExpiresAt int64 `json:"expires_at"` + CipherKeyEnc string `json:"cipher_key_enc"` +} + +// Signin requests a session token +func Signin(ctx infra.DnoteCtx, email, authKey string) (SigninResponse, error) { + payload := SigninPayload{ + Email: email, + AuthKey: authKey, + } + b, err := json.Marshal(payload) + if err != nil { + return SigninResponse{}, errors.Wrap(err, "marshaling payload") + } + res, err := utils.DoReq(ctx, "POST", "/v1/signin", string(b)) + if err != nil { + return SigninResponse{}, errors.Wrap(err, "making http request") + } + + if res.StatusCode == http.StatusUnauthorized { + return SigninResponse{}, ErrInvalidLogin + } + + hasErr, message, err := checkRespErr(res) + if err != nil { + return SigninResponse{}, errors.Wrap(err, "checking repsonse error") + } + if hasErr { + return SigninResponse{}, errors.New(message) + } + + var resp SigninResponse + if err := json.NewDecoder(res.Body).Decode(&resp); err != nil { + return SigninResponse{}, errors.Wrap(err, "decoding payload") + } + + return resp, nil +} + +// Signout deletes a user session on the server side +func Signout(ctx infra.DnoteCtx, sessionKey string) error { + hc := http.Client{ + // No need to follow redirect + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + res, err := utils.DoAuthorizedReq(ctx, hc, "POST", "/v1/signout", "") + if err != nil { + return errors.Wrap(err, "making http request") + } + + hasErr, message, err := checkRespErr(res) + if err != nil { + return errors.Wrap(err, "checking repsonse error") + } + if hasErr { + return errors.New(message) + } + + return nil +} diff --git a/cmd/login/login.go b/cmd/login/login.go index 77f0f060..c2bea8fe 100644 --- a/cmd/login/login.go +++ b/cmd/login/login.go @@ -1,11 +1,15 @@ package login import ( - "fmt" + "encoding/base64" + "strconv" + "github.com/dnote/cli/client" "github.com/dnote/cli/core" + "github.com/dnote/cli/crypt" "github.com/dnote/cli/infra" "github.com/dnote/cli/log" + "github.com/dnote/cli/utils" "github.com/pkg/errors" "github.com/spf13/cobra" ) @@ -25,36 +29,78 @@ func NewCmd(ctx infra.DnoteCtx) *cobra.Command { return cmd } +// Do dervies credentials on the client side and requests a session token from the server +func Do(ctx infra.DnoteCtx, email, password string) error { + presigninResp, err := client.GetPresignin(ctx, email) + if err != nil { + return errors.Wrap(err, "getting presiginin") + } + + masterKey, authKey, err := crypt.MakeKeys([]byte(password), []byte(email), presigninResp.Iteration) + if err != nil { + return errors.Wrap(err, "making keys") + } + + authKeyB64 := base64.StdEncoding.EncodeToString(authKey) + signinResp, err := client.Signin(ctx, email, authKeyB64) + if err != nil { + return errors.Wrap(err, "requesting session") + } + + cipherKeyDec, err := crypt.AesGcmDecrypt(masterKey, signinResp.CipherKeyEnc) + if err != nil { + return errors.Wrap(err, "decrypting cipher key") + } + + cipherKeyDecB64 := base64.StdEncoding.EncodeToString(cipherKeyDec) + + db := ctx.DB + tx, err := db.Begin() + if err != nil { + return errors.Wrap(err, "beginning a transaction") + } + + if err := core.UpsertSystem(tx, infra.SystemCipherKey, cipherKeyDecB64); err != nil { + return errors.Wrap(err, "saving enc key") + } + if err := core.UpsertSystem(tx, infra.SystemSessionKey, signinResp.Key); err != nil { + return errors.Wrap(err, "saving session key") + } + if err := core.UpsertSystem(tx, infra.SystemSessionKeyExpiry, strconv.FormatInt(signinResp.ExpiresAt, 10)); err != nil { + return errors.Wrap(err, "saving session key") + } + + tx.Commit() + + return nil +} + func newRun(ctx infra.DnoteCtx) core.RunEFunc { return func(cmd *cobra.Command, args []string) error { - log.Plain("\n") - log.Plain(" _( )_( )_\n") - log.Plain(" (_ _ _)\n") - log.Plain(" (_) (__)\n\n") - log.Plain("Welcome to Dnote Cloud :)\n\n") - log.Plain("A home for your engineering microlessons\n") - log.Plain("You can register at https://dnote.io/cloud\n\n") - log.Printf("API key: ") - - var apiKey string - fmt.Scanln(&apiKey) - - if apiKey == "" { - return errors.New("Empty API key") + var email, password string + if err := utils.PromptInput("email", &email); err != nil { + return errors.Wrap(err, "getting email input") + } + if email == "" { + return errors.New("Email is empty") } - config, err := core.ReadConfig(ctx) - if err != nil { - return err + if err := utils.PromptPassword("password", &password); err != nil { + return errors.Wrap(err, "getting password input") + } + if password == "" { + return errors.New("Password is empty") } - config.APIKey = apiKey - err = core.WriteConfig(ctx, config) - if err != nil { - return errors.Wrap(err, "Failed to write to config file") + err := Do(ctx, email, password) + if errors.Cause(err) == client.ErrInvalidLogin { + log.Error("wrong login\n") + return nil + } else if err != nil { + return errors.Wrap(err, "logging in") } - log.Success("configured\n") + log.Success("logged in\n") return nil } diff --git a/cmd/logout/logout.go b/cmd/logout/logout.go new file mode 100644 index 00000000..6a0eaf51 --- /dev/null +++ b/cmd/logout/logout.go @@ -0,0 +1,82 @@ +package logout + +import ( + "database/sql" + + "github.com/dnote/cli/client" + "github.com/dnote/cli/core" + "github.com/dnote/cli/infra" + "github.com/dnote/cli/log" + "github.com/pkg/errors" + "github.com/spf13/cobra" +) + +// ErrNotLoggedIn is an error for logging out when not logged in +var ErrNotLoggedIn = errors.New("not logged in") + +var example = ` + dnote logout` + +// NewCmd returns a new logout command +func NewCmd(ctx infra.DnoteCtx) *cobra.Command { + cmd := &cobra.Command{ + Use: "logout", + Short: "Logout from the server", + Example: example, + RunE: newRun(ctx), + } + + return cmd +} + +// Do performs logout +func Do(ctx infra.DnoteCtx) error { + db := ctx.DB + tx, err := db.Begin() + if err != nil { + return errors.Wrap(err, "beginning a transaction") + } + + var key string + err = core.GetSystem(tx, infra.SystemSessionKey, &key) + if errors.Cause(err) == sql.ErrNoRows { + return ErrNotLoggedIn + } else if err != nil { + return errors.Wrap(err, "getting session key") + } + + err = client.Signout(ctx, key) + if err != nil { + return errors.Wrap(err, "requesting logout") + } + + if err := core.DeleteSystem(tx, infra.SystemCipherKey); err != nil { + return errors.Wrap(err, "deleting enc key") + } + if err := core.DeleteSystem(tx, infra.SystemSessionKey); err != nil { + return errors.Wrap(err, "deleting session key") + } + if err := core.DeleteSystem(tx, infra.SystemSessionKeyExpiry); err != nil { + return errors.Wrap(err, "deleting session key expiry") + } + + tx.Commit() + + return nil +} + +func newRun(ctx infra.DnoteCtx) core.RunEFunc { + return func(cmd *cobra.Command, args []string) error { + err := Do(ctx) + if err == ErrNotLoggedIn { + log.Error("not logged in\n") + return nil + } else if err != nil { + return errors.Wrap(err, "logging out") + } + + log.Success("logged out\n") + + return nil + } +} diff --git a/cmd/sync/sync.go b/cmd/sync/sync.go index 6968e1c5..73b21827 100644 --- a/cmd/sync/sync.go +++ b/cmd/sync/sync.go @@ -6,6 +6,7 @@ import ( "github.com/dnote/cli/client" "github.com/dnote/cli/core" + "github.com/dnote/cli/crypt" "github.com/dnote/cli/infra" "github.com/dnote/cli/log" "github.com/dnote/cli/migrate" @@ -39,22 +40,20 @@ func NewCmd(ctx infra.DnoteCtx) *cobra.Command { return cmd } -func getLastSyncAt(tx *sql.Tx) (int, error) { +func getLastSyncAt(tx *infra.DB) (int, error) { var ret int - err := tx.QueryRow("SELECT value FROM system WHERE key = ?", infra.SystemLastSyncAt).Scan(&ret) - if err != nil { + if err := core.GetSystem(tx, infra.SystemLastSyncAt, &ret); err != nil { return ret, errors.Wrap(err, "querying last sync time") } return ret, nil } -func getLastMaxUSN(tx *sql.Tx) (int, error) { +func getLastMaxUSN(tx *infra.DB) (int, error) { var ret int - err := tx.QueryRow("SELECT value FROM system WHERE key = ?", infra.SystemLastMaxUSN).Scan(&ret) - if err != nil { + if err := core.GetSystem(tx, infra.SystemLastMaxUSN, &ret); err != nil { return ret, errors.Wrap(err, "querying last user max_usn") } @@ -75,7 +74,9 @@ func (l syncList) getLength() int { return len(l.Notes) + len(l.Books) + len(l.ExpungedNotes) + len(l.ExpungedBooks) } -func newSyncList(fragments []client.SyncFragment) syncList { +// processFragments categorizes items in sync fragments into a sync list. It also decrypts any +// encrypted data in sync fragments. +func processFragments(fragments []client.SyncFragment, cipherKey []byte) (syncList, error) { notes := map[string]client.SyncFragNote{} books := map[string]client.SyncFragBook{} expungedNotes := map[string]bool{} @@ -85,9 +86,23 @@ func newSyncList(fragments []client.SyncFragment) syncList { for _, fragment := range fragments { for _, note := range fragment.Notes { + log.Debug("decrypting note %s\n", note.UUID) + bodyDec, err := crypt.AesGcmDecrypt(cipherKey, note.Body) + if err != nil { + return syncList{}, errors.Wrapf(err, "decrypting body for note %s", note.UUID) + } + + note.Body = string(bodyDec) notes[note.UUID] = note } for _, book := range fragment.Books { + log.Debug("decrypting book %s\n", book.UUID) + labelDec, err := crypt.AesGcmDecrypt(cipherKey, book.Label) + if err != nil { + return syncList{}, errors.Wrapf(err, "decrypting label for book %s", book.UUID) + } + + book.Label = string(labelDec) books[book.UUID] = book } for _, uuid := range fragment.ExpungedBooks { @@ -105,7 +120,7 @@ func newSyncList(fragments []client.SyncFragment) syncList { } } - return syncList{ + sl := syncList{ Notes: notes, Books: books, ExpungedNotes: expungedNotes, @@ -113,30 +128,35 @@ func newSyncList(fragments []client.SyncFragment) syncList { MaxUSN: maxUSN, MaxCurrentTime: maxCurrentTime, } + + return sl, nil } // getSyncList gets a list of all sync fragments after the specified usn // and aggregates them into a syncList data structure -func getSyncList(ctx infra.DnoteCtx, apiKey string, afterUSN int) (syncList, error) { - fragments, err := getSyncFragments(ctx, apiKey, afterUSN) +func getSyncList(ctx infra.DnoteCtx, afterUSN int) (syncList, error) { + fragments, err := getSyncFragments(ctx, afterUSN) if err != nil { return syncList{}, errors.Wrap(err, "getting sync fragments") } - ret := newSyncList(fragments) + ret, err := processFragments(fragments, ctx.CipherKey) + if err != nil { + return syncList{}, errors.Wrap(err, "making sync list") + } return ret, nil } // getSyncFragments repeatedly gets all sync fragments after the specified usn until there is no more new data // remaining and returns the buffered list -func getSyncFragments(ctx infra.DnoteCtx, apiKey string, afterUSN int) ([]client.SyncFragment, error) { +func getSyncFragments(ctx infra.DnoteCtx, afterUSN int) ([]client.SyncFragment, error) { var buf []client.SyncFragment nextAfterUSN := afterUSN for { - resp, err := client.GetSyncFragment(ctx, apiKey, nextAfterUSN) + resp, err := client.GetSyncFragment(ctx, nextAfterUSN) if err != nil { return buf, errors.Wrap(err, "getting sync fragment") } @@ -159,7 +179,7 @@ func getSyncFragments(ctx infra.DnoteCtx, apiKey string, afterUSN int) ([]client // resolveLabel resolves a book label conflict by repeatedly appending an increasing integer // to the label until it finds a unique label. It returns the first non-conflicting label. -func resolveLabel(tx *sql.Tx, label string) (string, error) { +func resolveLabel(tx *infra.DB, label string) (string, error) { var ret string for i := 2; ; i++ { @@ -180,7 +200,7 @@ func resolveLabel(tx *sql.Tx, label string) (string, error) { // mergeBook inserts or updates the given book in the local database. // If a book with a duplicate label exists locally, it renames the duplicate by appending a number. -func mergeBook(tx *sql.Tx, b client.SyncFragBook, mode int) error { +func mergeBook(tx *infra.DB, b client.SyncFragBook, mode int) error { var count int if err := tx.QueryRow("SELECT count(*) FROM books WHERE label = ?", b.Label).Scan(&count); err != nil { return errors.Wrapf(err, "checking for books with a duplicate label %s", b.Label) @@ -214,7 +234,7 @@ func mergeBook(tx *sql.Tx, b client.SyncFragBook, mode int) error { return nil } -func stepSyncBook(tx *sql.Tx, b client.SyncFragBook) error { +func stepSyncBook(tx *infra.DB, b client.SyncFragBook) error { var localUSN int var dirty bool err := tx.QueryRow("SELECT usn, dirty FROM books WHERE uuid = ?", b.UUID).Scan(&localUSN, &dirty) @@ -238,7 +258,7 @@ func stepSyncBook(tx *sql.Tx, b client.SyncFragBook) error { return nil } -func mergeNote(tx *sql.Tx, serverNote client.SyncFragNote, localNote core.Note) error { +func mergeNote(tx *infra.DB, serverNote client.SyncFragNote, localNote core.Note) error { var bookDeleted bool err := tx.QueryRow("SELECT deleted FROM books WHERE uuid = ?", localNote.BookUUID).Scan(&bookDeleted) if err != nil { @@ -269,7 +289,7 @@ func mergeNote(tx *sql.Tx, serverNote client.SyncFragNote, localNote core.Note) return nil } -func stepSyncNote(tx *sql.Tx, n client.SyncFragNote) error { +func stepSyncNote(tx *infra.DB, n client.SyncFragNote) error { var localNote core.Note err := tx.QueryRow("SELECT usn, book_uuid, dirty, deleted FROM notes WHERE uuid = ?", n.UUID). Scan(&localNote.USN, &localNote.BookUUID, &localNote.Dirty, &localNote.Deleted) @@ -293,7 +313,7 @@ func stepSyncNote(tx *sql.Tx, n client.SyncFragNote) error { return nil } -func fullSyncNote(tx *sql.Tx, n client.SyncFragNote) error { +func fullSyncNote(tx *infra.DB, n client.SyncFragNote) error { var localNote core.Note err := tx.QueryRow("SELECT usn,book_uuid, dirty, deleted FROM notes WHERE uuid = ?", n.UUID). Scan(&localNote.USN, &localNote.BookUUID, &localNote.Dirty, &localNote.Deleted) @@ -317,7 +337,7 @@ func fullSyncNote(tx *sql.Tx, n client.SyncFragNote) error { return nil } -func syncDeleteNote(tx *sql.Tx, noteUUID string) error { +func syncDeleteNote(tx *infra.DB, noteUUID string) error { var localUSN int var dirty bool err := tx.QueryRow("SELECT usn, dirty FROM notes WHERE uuid = ?", noteUUID).Scan(&localUSN, &dirty) @@ -342,7 +362,7 @@ func syncDeleteNote(tx *sql.Tx, noteUUID string) error { } // checkNotesPristine checks that none of the notes in the given book are dirty -func checkNotesPristine(tx *sql.Tx, bookUUID string) (bool, error) { +func checkNotesPristine(tx *infra.DB, bookUUID string) (bool, error) { var count int if err := tx.QueryRow("SELECT count(*) FROM notes WHERE book_uuid = ? AND dirty = ?", bookUUID, true).Scan(&count); err != nil { return false, errors.Wrapf(err, "counting notes that are dirty in book %s", bookUUID) @@ -355,7 +375,7 @@ func checkNotesPristine(tx *sql.Tx, bookUUID string) (bool, error) { return true, nil } -func syncDeleteBook(tx *sql.Tx, bookUUID string) error { +func syncDeleteBook(tx *infra.DB, bookUUID string) error { var localUSN int var dirty bool err := tx.QueryRow("SELECT usn, dirty FROM books WHERE uuid = ?", bookUUID).Scan(&localUSN, &dirty) @@ -401,7 +421,7 @@ func syncDeleteBook(tx *sql.Tx, bookUUID string) error { return nil } -func fullSyncBook(tx *sql.Tx, b client.SyncFragBook) error { +func fullSyncBook(tx *infra.DB, b client.SyncFragBook) error { var localUSN int var dirty bool err := tx.QueryRow("SELECT usn, dirty FROM books WHERE uuid = ?", b.UUID).Scan(&localUSN, &dirty) @@ -453,7 +473,7 @@ func checkBookInList(uuid string, list *syncList) bool { // judging by the full list of resources in the server. Concretely, the only acceptable // situation in which a local note is not present in the server is if it is new and has not been // uploaded (i.e. dirty and usn is 0). Otherwise, it is a result of some kind of error and should be cleaned. -func cleanLocalNotes(tx *sql.Tx, fullList *syncList) error { +func cleanLocalNotes(tx *infra.DB, fullList *syncList) error { rows, err := tx.Query("SELECT uuid, usn, dirty FROM notes") if err != nil { return errors.Wrap(err, "getting local notes") @@ -479,7 +499,7 @@ func cleanLocalNotes(tx *sql.Tx, fullList *syncList) error { } // cleanLocalBooks deletes from the local database any books that are in invalid state -func cleanLocalBooks(tx *sql.Tx, fullList *syncList) error { +func cleanLocalBooks(tx *infra.DB, fullList *syncList) error { rows, err := tx.Query("SELECT uuid, usn, dirty FROM books") if err != nil { return errors.Wrap(err, "getting local books") @@ -504,11 +524,11 @@ func cleanLocalBooks(tx *sql.Tx, fullList *syncList) error { return nil } -func fullSync(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) error { +func fullSync(ctx infra.DnoteCtx, tx *infra.DB) error { log.Debug("performing a full sync\n") log.Info("resolving delta.") - list, err := getSyncList(ctx, apiKey, 0) + list, err := getSyncList(ctx, 0) if err != nil { return errors.Wrap(err, "getting sync list") } @@ -555,12 +575,12 @@ func fullSync(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) error { return nil } -func stepSync(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string, afterUSN int) error { +func stepSync(ctx infra.DnoteCtx, tx *infra.DB, afterUSN int) error { log.Debug("performing a step sync\n") log.Info("resolving delta.") - list, err := getSyncList(ctx, apiKey, afterUSN) + list, err := getSyncList(ctx, afterUSN) if err != nil { return errors.Wrap(err, "getting sync list") } @@ -599,7 +619,7 @@ func stepSync(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string, afterUSN int) error return nil } -func sendBooks(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) { +func sendBooks(ctx infra.DnoteCtx, tx *infra.DB) (bool, error) { isBehind := false rows, err := tx.Query("SELECT uuid, label, usn, deleted FROM books WHERE dirty") @@ -629,7 +649,7 @@ func sendBooks(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) { continue } else { - resp, err := client.CreateBook(ctx, apiKey, book.Label) + resp, err := client.CreateBook(ctx, book.Label) if err != nil { return isBehind, errors.Wrap(err, "creating a book") } @@ -655,7 +675,7 @@ func sendBooks(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) { } } else { if book.Deleted { - resp, err := client.DeleteBook(ctx, apiKey, book.UUID) + resp, err := client.DeleteBook(ctx, book.UUID) if err != nil { return isBehind, errors.Wrap(err, "deleting a book") } @@ -667,7 +687,7 @@ func sendBooks(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) { respUSN = resp.Book.USN } else { - resp, err := client.UpdateBook(ctx, apiKey, book.Label, book.UUID) + resp, err := client.UpdateBook(ctx, book.Label, book.UUID) if err != nil { return isBehind, errors.Wrap(err, "updating a book") } @@ -703,7 +723,7 @@ func sendBooks(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) { return isBehind, nil } -func sendNotes(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) { +func sendNotes(ctx infra.DnoteCtx, tx *infra.DB) (bool, error) { isBehind := false rows, err := tx.Query("SELECT uuid, book_uuid, body, public, deleted, usn FROM notes WHERE dirty") @@ -734,7 +754,7 @@ func sendNotes(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) { continue } else { - resp, err := client.CreateNote(ctx, apiKey, note.BookUUID, note.Body) + resp, err := client.CreateNote(ctx, note.BookUUID, note.Body) if err != nil { return isBehind, errors.Wrap(err, "creating a note") } @@ -755,7 +775,7 @@ func sendNotes(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) { } } else { if note.Deleted { - resp, err := client.DeleteNote(ctx, apiKey, note.UUID) + resp, err := client.DeleteNote(ctx, note.UUID) if err != nil { return isBehind, errors.Wrap(err, "deleting a note") } @@ -767,7 +787,7 @@ func sendNotes(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) { respUSN = resp.Result.USN } else { - resp, err := client.UpdateNote(ctx, apiKey, note.UUID, note.BookUUID, note.Body, note.Public) + resp, err := client.UpdateNote(ctx, note.UUID, note.BookUUID, note.Body, note.Public) if err != nil { return isBehind, errors.Wrap(err, "updating a note") } @@ -803,7 +823,7 @@ func sendNotes(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) { return isBehind, nil } -func sendChanges(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) { +func sendChanges(ctx infra.DnoteCtx, tx *infra.DB) (bool, error) { log.Info("sending changes.") var delta int @@ -811,12 +831,12 @@ func sendChanges(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) { fmt.Printf(" (total %d).", delta) - behind1, err := sendBooks(ctx, tx, apiKey) + behind1, err := sendBooks(ctx, tx) if err != nil { return behind1, errors.Wrap(err, "sending books") } - behind2, err := sendNotes(ctx, tx, apiKey) + behind2, err := sendNotes(ctx, tx) if err != nil { return behind2, errors.Wrap(err, "sending notes") } @@ -828,25 +848,23 @@ func sendChanges(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) { return isBehind, nil } -func updateLastMaxUSN(tx *sql.Tx, val int) error { - _, err := tx.Exec("UPDATE system SET value = ? WHERE key = ?", val, infra.SystemLastMaxUSN) - if err != nil { +func updateLastMaxUSN(tx *infra.DB, val int) error { + if err := core.UpdateSystem(tx, infra.SystemLastMaxUSN, val); err != nil { return errors.Wrapf(err, "updating %s", infra.SystemLastMaxUSN) } return nil } -func updateLastSyncAt(tx *sql.Tx, val int64) error { - _, err := tx.Exec("UPDATE system SET value = ? WHERE key = ?", val, infra.SystemLastSyncAt) - if err != nil { +func updateLastSyncAt(tx *infra.DB, val int64) error { + if err := core.UpdateSystem(tx, infra.SystemLastSyncAt, val); err != nil { return errors.Wrapf(err, "updating %s", infra.SystemLastSyncAt) } return nil } -func saveSyncState(tx *sql.Tx, serverTime int64, serverMaxUSN int) error { +func saveSyncState(tx *infra.DB, serverTime int64, serverMaxUSN int) error { if err := updateLastMaxUSN(tx, serverMaxUSN); err != nil { return errors.Wrap(err, "updating last max usn") } @@ -859,26 +877,20 @@ func saveSyncState(tx *sql.Tx, serverTime int64, serverMaxUSN int) error { func newRun(ctx infra.DnoteCtx) core.RunEFunc { return func(cmd *cobra.Command, args []string) error { - config, err := core.ReadConfig(ctx) - if err != nil { - return errors.Wrap(err, "reading the config") - } - if config.APIKey == "" { - log.Error("login required. please run `dnote login`\n") - return nil + if ctx.SessionKey == "" || ctx.CipherKey == nil { + return errors.New("not logged in") } if err := migrate.Run(ctx, migrate.RemoteSequence, migrate.RemoteMode); err != nil { return errors.Wrap(err, "running remote migrations") } - db := ctx.DB - tx, err := db.Begin() + tx, err := ctx.DB.Begin() if err != nil { return errors.Wrap(err, "beginning a transaction") } - syncState, err := client.GetSyncState(config.APIKey, ctx) + syncState, err := client.GetSyncState(ctx) if err != nil { return errors.Wrap(err, "getting the sync state from the server") } @@ -895,9 +907,9 @@ func newRun(ctx infra.DnoteCtx) core.RunEFunc { var syncErr error if isFullSync || lastSyncAt < syncState.FullSyncBefore { - syncErr = fullSync(ctx, tx, config.APIKey) + syncErr = fullSync(ctx, tx) } else if lastMaxUSN != syncState.MaxUSN { - syncErr = stepSync(ctx, tx, config.APIKey, lastMaxUSN) + syncErr = stepSync(ctx, tx, lastMaxUSN) } else { // if no need to sync from the server, simply update the last sync timestamp and proceed to send changes err = updateLastSyncAt(tx, syncState.CurrentTime) @@ -907,10 +919,10 @@ func newRun(ctx infra.DnoteCtx) core.RunEFunc { } if syncErr != nil { tx.Rollback() - return errors.Wrap(err, "syncing changes from the server") + return errors.Wrap(syncErr, "syncing changes from the server") } - isBehind, err := sendChanges(ctx, tx, config.APIKey) + isBehind, err := sendChanges(ctx, tx) if err != nil { tx.Rollback() return errors.Wrap(err, "sending changes") @@ -926,7 +938,7 @@ func newRun(ctx infra.DnoteCtx) core.RunEFunc { return errors.Wrap(err, "getting the new last max_usn") } - err = stepSync(ctx, tx, config.APIKey, updatedLastMaxUSN) + err = stepSync(ctx, tx, updatedLastMaxUSN) if err != nil { tx.Rollback() return errors.Wrap(err, "performing the follow-up step sync") diff --git a/cmd/sync/sync_test.go b/cmd/sync/sync_test.go index 9ca9b44a..3fcd201c 100644 --- a/cmd/sync/sync_test.go +++ b/cmd/sync/sync_test.go @@ -11,12 +11,87 @@ import ( "github.com/dnote/cli/client" "github.com/dnote/cli/core" + "github.com/dnote/cli/crypt" "github.com/dnote/cli/infra" "github.com/dnote/cli/testutils" "github.com/dnote/cli/utils" "github.com/pkg/errors" ) +var cipherKey = []byte("AES256Key-32Characters1234567890") + +func TestProcessFragments(t *testing.T) { + // set up + ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true) + defer testutils.TeardownEnv(ctx) + + fragments := []client.SyncFragment{ + client.SyncFragment{ + FragMaxUSN: 10, + UserMaxUSN: 10, + CurrentTime: 1550436136, + Notes: []client.SyncFragNote{ + client.SyncFragNote{ + UUID: "45546de0-40ed-45cf-9bfc-62ce729a7d3d", + Body: "7GgIppDdxDn+4DUoVoLXbncZDRqXGwbDVNF/eCssu+1BXMdq+HAziJHGgK7drdcIBtYDDXj0OwHz9dQDDOyWeNqkLWEIQ2Roygs229dRxdO3Z6ST+qSOr/9TTjDlFxydF5Ps7nAXdN9KVxH8FKIZDsxJ45qeLKpQK/6poAM39BCOiysqAXJQz9ngOJiqImAuftS6d/XhwX77QvnM91VCKK0tFmsMdDDw0J9QMwnlYU1CViHy1Hdhhcf9Ea38Mj4SCrWMPscXyP2fpAu5ukbIK3vS2pvbnH5vC8ZuvihrQif1BsiwfYmN981mLYs069Dn4B72qcXPwU7qrN3V0k57JGcAlTiEoOD5QowyraensQlR1doorLb43SjTiJLItougn5K5QPRiHuNxfv39pa7A0gKA1n/3UhG/SBuCpDuPYjwmBkvkzCKJNgpbLQ8p29JXMQcWrm4e9GfnVjMhAEtxttIta3MN6EcYG7cB1dJ04OLYVcJuRA==", + }, + client.SyncFragNote{ + UUID: "a25a5336-afe9-46c4-b881-acab911c0bc3", + Body: "WGzcYA6kLuUFEU7HLTDJt7UWF7fEmbCPHfC16VBrAyfT2wDejXbIuFpU5L7g0aU=", + }, + }, + Books: []client.SyncFragBook{ + client.SyncFragBook{ + UUID: "e8ac6f25-d95b-435a-9fae-094f7506a5ac", + Label: "qBrSrAcnTUHu51bIrv6jSA/dNffr/kRlIg+MklxeQQ==", + }, + client.SyncFragBook{ + UUID: "05fd8b95-ddcd-4071-9380-4358ffb8a436", + Label: "uHWoBFdKT78gTkFR7qhyzZkrn59c8ktEa8idrLkksKzIQ3VVAXxq0QZp7Uc=", + }, + }, + ExpungedNotes: []string{}, + ExpungedBooks: []string{}, + }, + } + + // exec + sl, err := processFragments(fragments, cipherKey) + if err != nil { + t.Fatalf(errors.Wrap(err, "executing").Error()) + } + + expected := syncList{ + Notes: map[string]client.SyncFragNote{ + "45546de0-40ed-45cf-9bfc-62ce729a7d3d": client.SyncFragNote{ + UUID: "45546de0-40ed-45cf-9bfc-62ce729a7d3d", + Body: "Lorem ipsum dolor sit amet, consectetur adipiscing elit.\n Donec ac libero efficitur, posuere dui non, egestas lectus.\n Aliquam urna ligula, sagittis eu volutpat vel, consequat et augue.\n\n Ut mi urna, dignissim a ex eget, venenatis accumsan sem. Praesent facilisis, ligula hendrerit auctor varius, mauris metus hendrerit dolor, sit amet pulvinar.", + }, + "a25a5336-afe9-46c4-b881-acab911c0bc3": client.SyncFragNote{ + UUID: "a25a5336-afe9-46c4-b881-acab911c0bc3", + Body: "foo bar baz quz\nqux", + }, + }, + Books: map[string]client.SyncFragBook{ + "e8ac6f25-d95b-435a-9fae-094f7506a5ac": client.SyncFragBook{ + UUID: "e8ac6f25-d95b-435a-9fae-094f7506a5ac", + Label: "foo", + }, + "05fd8b95-ddcd-4071-9380-4358ffb8a436": client.SyncFragBook{ + UUID: "05fd8b95-ddcd-4071-9380-4358ffb8a436", + Label: "foo-bar-baz-1000", + }, + }, + ExpungedNotes: map[string]bool{}, + ExpungedBooks: map[string]bool{}, + MaxUSN: 10, + MaxCurrentTime: 1550436136, + } + + // test + testutils.AssertDeepEqual(t, sl, expected, "syncList mismatch") +} + func TestGetLastSyncAt(t *testing.T) { // set up ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true) @@ -1769,6 +1844,7 @@ func TestMergeBook(t *testing.T) { func TestSaveServerState(t *testing.T) { // set up ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true) + testutils.Login(t, &ctx) defer testutils.TeardownEnv(ctx) db := ctx.DB @@ -1812,6 +1888,7 @@ func TestSaveServerState(t *testing.T) { func TestSendBooks(t *testing.T) { // set up ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true) + testutils.Login(t, &ctx) defer testutils.TeardownEnv(ctx) db := ctx.DB @@ -1846,9 +1923,9 @@ func TestSendBooks(t *testing.T) { var updatesUUIDs []string var deletedUUIDs []string - // fire up a test server + // fire up a test server. It decrypts the payload for test purposes. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.String() == "/v1/books" && r.Method == "POST" { + if r.URL.String() == "/v2/books" && r.Method == "POST" { var payload client.CreateBookPayload err := json.NewDecoder(r.Body).Decode(&payload) @@ -1857,15 +1934,22 @@ func TestSendBooks(t *testing.T) { return } - createdLabels = append(createdLabels, payload.Name) + labelDec, err := crypt.AesGcmDecrypt(cipherKey, payload.Name) + if err != nil { + t.Fatalf(errors.Wrap(err, "decrypting label").Error()) + } + + labelDecStr := string(labelDec) + + createdLabels = append(createdLabels, labelDecStr) resp := client.CreateBookResp{ Book: client.RespBook{ - UUID: fmt.Sprintf("server-%s-uuid", payload.Name), + UUID: fmt.Sprintf("server-%s-uuid", labelDecStr), }, } - w.Header().Set("Body-Type", "application/json") + w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(resp); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -1879,14 +1963,14 @@ func TestSendBooks(t *testing.T) { uuid := p[3] updatesUUIDs = append(updatesUUIDs, uuid) - w.Header().Set("Body-Type", "application/json") + w.Header().Set("Content-Type", "application/json") w.Write([]byte("{}")) return } else if r.Method == "DELETE" { uuid := p[3] deletedUUIDs = append(deletedUUIDs, uuid) - w.Header().Set("Body-Type", "application/json") + w.Header().Set("Content-Type", "application/json") w.Write([]byte("{}")) return } @@ -1904,13 +1988,16 @@ func TestSendBooks(t *testing.T) { t.Fatalf(errors.Wrap(err, "beginning a transaction").Error()) } - if _, err := sendBooks(ctx, tx, "mockAPIKey"); err != nil { + if _, err := sendBooks(ctx, tx); err != nil { tx.Rollback() t.Fatalf(errors.Wrap(err, "executing").Error()) } tx.Commit() + // test + + // First, decrypt data so that they can be asserted sort.SliceStable(createdLabels, func(i, j int) bool { return strings.Compare(createdLabels[i], createdLabels[j]) < 0 }) @@ -1964,7 +2051,7 @@ func TestSendBooks(t *testing.T) { func TestSendBooks_isBehind(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.String() == "/v1/books" && r.Method == "POST" { + if r.URL.String() == "/v2/books" && r.Method == "POST" { var payload client.CreateBookPayload err := json.NewDecoder(r.Body).Decode(&payload) @@ -1979,7 +2066,7 @@ func TestSendBooks_isBehind(t *testing.T) { }, } - w.Header().Set("Body-Type", "application/json") + w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(resp); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -1996,7 +2083,7 @@ func TestSendBooks_isBehind(t *testing.T) { }, } - w.Header().Set("Body-Type", "application/json") + w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(resp); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -2009,7 +2096,7 @@ func TestSendBooks_isBehind(t *testing.T) { }, } - w.Header().Set("Body-Type", "application/json") + w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(resp); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -2041,6 +2128,7 @@ func TestSendBooks_isBehind(t *testing.T) { func() { // set up ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true) + testutils.Login(t, &ctx) ctx.APIEndpoint = ts.URL defer testutils.TeardownEnv(ctx) @@ -2055,7 +2143,7 @@ func TestSendBooks_isBehind(t *testing.T) { t.Fatalf(errors.Wrap(err, fmt.Sprintf("beginning a transaction for test case %d", idx)).Error()) } - isBehind, err := sendBooks(ctx, tx, "mockAPIKey") + isBehind, err := sendBooks(ctx, tx) if err != nil { tx.Rollback() t.Fatalf(errors.Wrap(err, fmt.Sprintf("executing for test case %d", idx)).Error()) @@ -2088,6 +2176,7 @@ func TestSendBooks_isBehind(t *testing.T) { func() { // set up ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true) + testutils.Login(t, &ctx) ctx.APIEndpoint = ts.URL defer testutils.TeardownEnv(ctx) @@ -2102,7 +2191,7 @@ func TestSendBooks_isBehind(t *testing.T) { t.Fatalf(errors.Wrap(err, fmt.Sprintf("beginning a transaction for test case %d", idx)).Error()) } - isBehind, err := sendBooks(ctx, tx, "mockAPIKey") + isBehind, err := sendBooks(ctx, tx) if err != nil { tx.Rollback() t.Fatalf(errors.Wrap(err, fmt.Sprintf("executing for test case %d", idx)).Error()) @@ -2135,6 +2224,7 @@ func TestSendBooks_isBehind(t *testing.T) { func() { // set up ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true) + testutils.Login(t, &ctx) ctx.APIEndpoint = ts.URL defer testutils.TeardownEnv(ctx) @@ -2149,7 +2239,7 @@ func TestSendBooks_isBehind(t *testing.T) { t.Fatalf(errors.Wrap(err, fmt.Sprintf("beginning a transaction for test case %d", idx)).Error()) } - isBehind, err := sendBooks(ctx, tx, "mockAPIKey") + isBehind, err := sendBooks(ctx, tx) if err != nil { tx.Rollback() t.Fatalf(errors.Wrap(err, fmt.Sprintf("executing for test case %d", idx)).Error()) @@ -2169,6 +2259,7 @@ func TestSendBooks_isBehind(t *testing.T) { func TestSendNotes(t *testing.T) { // set up ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true) + testutils.Login(t, &ctx) defer testutils.TeardownEnv(ctx) db := ctx.DB @@ -2203,9 +2294,9 @@ func TestSendNotes(t *testing.T) { var updatedUUIDs []string var deletedUUIDs []string - // fire up a test server + // fire up a test server. It decrypts the payload for test purposes. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.String() == "/v1/notes" && r.Method == "POST" { + if r.URL.String() == "/v2/notes" && r.Method == "POST" { var payload client.CreateNotePayload err := json.NewDecoder(r.Body).Decode(&payload) @@ -2214,15 +2305,21 @@ func TestSendNotes(t *testing.T) { return } - createdBodys = append(createdBodys, payload.Body) + bodyDec, err := crypt.AesGcmDecrypt(cipherKey, payload.Body) + if err != nil { + t.Fatalf(errors.Wrap(err, "decrypting body").Error()) + } + bodyDecStr := string(bodyDec) + + createdBodys = append(createdBodys, bodyDecStr) resp := client.CreateNoteResp{ Result: client.RespNote{ - UUID: fmt.Sprintf("server-%s-uuid", payload.Body), + UUID: fmt.Sprintf("server-%s-uuid", bodyDecStr), }, } - w.Header().Set("Body-Type", "application/json") + w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(resp); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -2236,14 +2333,14 @@ func TestSendNotes(t *testing.T) { uuid := p[3] updatedUUIDs = append(updatedUUIDs, uuid) - w.Header().Set("Body-Type", "application/json") + w.Header().Set("Content-Type", "application/json") w.Write([]byte("{}")) return } else if r.Method == "DELETE" { uuid := p[3] deletedUUIDs = append(deletedUUIDs, uuid) - w.Header().Set("Body-Type", "application/json") + w.Header().Set("Content-Type", "application/json") w.Write([]byte("{}")) return } @@ -2261,7 +2358,7 @@ func TestSendNotes(t *testing.T) { t.Fatalf(errors.Wrap(err, "beginning a transaction").Error()) } - if _, err := sendNotes(ctx, tx, "mockAPIKey"); err != nil { + if _, err := sendNotes(ctx, tx); err != nil { tx.Rollback() t.Fatalf(errors.Wrap(err, "executing").Error()) } @@ -2327,7 +2424,7 @@ func TestSendNotes_isBehind(t *testing.T) { }, } - w.Header().Set("Body-Type", "application/json") + w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(resp); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -2344,7 +2441,7 @@ func TestSendNotes_isBehind(t *testing.T) { }, } - w.Header().Set("Body-Type", "application/json") + w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(resp); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -2357,7 +2454,7 @@ func TestSendNotes_isBehind(t *testing.T) { }, } - w.Header().Set("Body-Type", "application/json") + w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(resp); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -2389,6 +2486,7 @@ func TestSendNotes_isBehind(t *testing.T) { func() { // set up ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true) + testutils.Login(t, &ctx) ctx.APIEndpoint = ts.URL defer testutils.TeardownEnv(ctx) @@ -2404,7 +2502,7 @@ func TestSendNotes_isBehind(t *testing.T) { t.Fatalf(errors.Wrap(err, fmt.Sprintf("beginning a transaction for test case %d", idx)).Error()) } - isBehind, err := sendNotes(ctx, tx, "mockAPIKey") + isBehind, err := sendNotes(ctx, tx) if err != nil { tx.Rollback() t.Fatalf(errors.Wrap(err, fmt.Sprintf("executing for test case %d", idx)).Error()) @@ -2437,6 +2535,7 @@ func TestSendNotes_isBehind(t *testing.T) { func() { // set up ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true) + testutils.Login(t, &ctx) ctx.APIEndpoint = ts.URL defer testutils.TeardownEnv(ctx) @@ -2452,7 +2551,7 @@ func TestSendNotes_isBehind(t *testing.T) { t.Fatalf(errors.Wrap(err, fmt.Sprintf("beginning a transaction for test case %d", idx)).Error()) } - isBehind, err := sendNotes(ctx, tx, "mockAPIKey") + isBehind, err := sendNotes(ctx, tx) if err != nil { tx.Rollback() t.Fatalf(errors.Wrap(err, fmt.Sprintf("executing for test case %d", idx)).Error()) @@ -2485,6 +2584,7 @@ func TestSendNotes_isBehind(t *testing.T) { func() { // set up ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true) + testutils.Login(t, &ctx) ctx.APIEndpoint = ts.URL defer testutils.TeardownEnv(ctx) @@ -2500,7 +2600,7 @@ func TestSendNotes_isBehind(t *testing.T) { t.Fatalf(errors.Wrap(err, fmt.Sprintf("beginning a transaction for test case %d", idx)).Error()) } - isBehind, err := sendNotes(ctx, tx, "mockAPIKey") + isBehind, err := sendNotes(ctx, tx) if err != nil { tx.Rollback() t.Fatalf(errors.Wrap(err, fmt.Sprintf("executing for test case %d", idx)).Error()) diff --git a/core/core.go b/core/core.go index ad94d574..7bccca9e 100644 --- a/core/core.go +++ b/core/core.go @@ -11,6 +11,7 @@ import ( "time" "github.com/dnote/cli/infra" + "github.com/dnote/cli/log" "github.com/dnote/cli/utils" "github.com/pkg/errors" "github.com/satori/go.uuid" @@ -267,9 +268,9 @@ func GetEditorInput(ctx infra.DnoteCtx, fpath string, content *string) error { return nil } -func initSystemKV(tx *sql.Tx, key string, val string) error { +func initSystemKV(db *infra.DB, key string, val string) error { var count int - if err := tx.QueryRow("SELECT count(*) FROM system WHERE key = ?", key).Scan(&count); err != nil { + if err := db.QueryRow("SELECT count(*) FROM system WHERE key = ?", key).Scan(&count); err != nil { return errors.Wrapf(err, "counting %s", key) } @@ -277,8 +278,8 @@ func initSystemKV(tx *sql.Tx, key string, val string) error { return nil } - if _, err := tx.Exec("INSERT INTO system (key, value) VALUES (?, ?)", key, val); err != nil { - tx.Rollback() + if _, err := db.Exec("INSERT INTO system (key, value) VALUES (?, ?)", key, val); err != nil { + db.Rollback() return errors.Wrapf(err, "inserting %s %s", key, val) } @@ -310,3 +311,30 @@ func InitSystem(ctx infra.DnoteCtx) error { return nil } + +// GetValidSession returns a session key from the local storage if one exists and is not expired +// If one does not exist or is expired, it prints out an instruction and returns false +func GetValidSession(ctx infra.DnoteCtx) (string, bool, error) { + db := ctx.DB + + var sessionKey string + var sessionKeyExpires int64 + + if err := GetSystem(db, infra.SystemSessionKey, &sessionKey); err != nil { + return "", false, errors.Wrap(err, "getting session key") + } + if err := GetSystem(db, infra.SystemSessionKeyExpiry, &sessionKeyExpires); err != nil { + return "", false, errors.Wrap(err, "getting session key expiry") + } + + if sessionKey == "" { + log.Error("login required. please run `dnote login`\n") + return "", false, nil + } + if sessionKeyExpires < time.Now().Unix() { + log.Error("sesison expired. please run `dnote login`\n") + return "", false, nil + } + + return sessionKey, true, nil +} diff --git a/core/models.go b/core/models.go index 38e91036..6f8548cb 100644 --- a/core/models.go +++ b/core/models.go @@ -1,8 +1,7 @@ package core import ( - "database/sql" - + "github.com/dnote/cli/infra" "github.com/pkg/errors" ) @@ -45,8 +44,8 @@ func NewNote(uuid, bookUUID, body string, addedOn, editedOn int64, usn int, publ } // Insert inserts a new note -func (n Note) Insert(tx *sql.Tx) error { - _, err := tx.Exec("INSERT INTO notes (uuid, book_uuid, body, added_on, edited_on, usn, public, deleted, dirty) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", +func (n Note) Insert(db *infra.DB) error { + _, err := db.Exec("INSERT INTO notes (uuid, book_uuid, body, added_on, edited_on, usn, public, deleted, dirty) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", n.UUID, n.BookUUID, n.Body, n.AddedOn, n.EditedOn, n.USN, n.Public, n.Deleted, n.Dirty) if err != nil { @@ -57,8 +56,8 @@ func (n Note) Insert(tx *sql.Tx) error { } // Update updates the note with the given data -func (n Note) Update(tx *sql.Tx) error { - _, err := tx.Exec("UPDATE notes SET book_uuid = ?, body = ?, added_on = ?, edited_on = ?, usn = ?, public = ?, deleted = ?, dirty = ? WHERE uuid = ?", +func (n Note) Update(db *infra.DB) error { + _, err := db.Exec("UPDATE notes SET book_uuid = ?, body = ?, added_on = ?, edited_on = ?, usn = ?, public = ?, deleted = ?, dirty = ? WHERE uuid = ?", n.BookUUID, n.Body, n.AddedOn, n.EditedOn, n.USN, n.Public, n.Deleted, n.Dirty, n.UUID) if err != nil { @@ -69,8 +68,8 @@ func (n Note) Update(tx *sql.Tx) error { } // UpdateUUID updates the uuid of a book -func (n *Note) UpdateUUID(tx *sql.Tx, newUUID string) error { - _, err := tx.Exec("UPDATE notes SET uuid = ? WHERE uuid = ?", newUUID, n.UUID) +func (n *Note) UpdateUUID(db *infra.DB, newUUID string) error { + _, err := db.Exec("UPDATE notes SET uuid = ? WHERE uuid = ?", newUUID, n.UUID) if err != nil { return errors.Wrapf(err, "updating note uuid from '%s' to '%s'", n.UUID, newUUID) @@ -82,8 +81,8 @@ func (n *Note) UpdateUUID(tx *sql.Tx, newUUID string) error { } // Expunge hard-deletes the note from the database -func (n Note) Expunge(tx *sql.Tx) error { - _, err := tx.Exec("DELETE FROM notes WHERE uuid = ?", n.UUID) +func (n Note) Expunge(db *infra.DB) error { + _, err := db.Exec("DELETE FROM notes WHERE uuid = ?", n.UUID) if err != nil { return errors.Wrap(err, "expunging a note locally") } @@ -103,8 +102,8 @@ func NewBook(uuid, label string, usn int, deleted, dirty bool) Book { } // Insert inserts a new book -func (b Book) Insert(tx *sql.Tx) error { - _, err := tx.Exec("INSERT INTO books (uuid, label, usn, dirty, deleted) VALUES (?, ?, ?, ?, ?)", +func (b Book) Insert(db *infra.DB) error { + _, err := db.Exec("INSERT INTO books (uuid, label, usn, dirty, deleted) VALUES (?, ?, ?, ?, ?)", b.UUID, b.Label, b.USN, b.Dirty, b.Deleted) if err != nil { @@ -115,8 +114,8 @@ func (b Book) Insert(tx *sql.Tx) error { } // Update updates the book with the given data -func (b Book) Update(tx *sql.Tx) error { - _, err := tx.Exec("UPDATE books SET label = ?, usn = ?, dirty = ?, deleted = ? WHERE uuid = ?", +func (b Book) Update(db *infra.DB) error { + _, err := db.Exec("UPDATE books SET label = ?, usn = ?, dirty = ?, deleted = ? WHERE uuid = ?", b.Label, b.USN, b.Dirty, b.Deleted, b.UUID) if err != nil { @@ -127,8 +126,8 @@ func (b Book) Update(tx *sql.Tx) error { } // UpdateUUID updates the uuid of a book -func (b *Book) UpdateUUID(tx *sql.Tx, newUUID string) error { - _, err := tx.Exec("UPDATE books SET uuid = ? WHERE uuid = ?", newUUID, b.UUID) +func (b *Book) UpdateUUID(db *infra.DB, newUUID string) error { + _, err := db.Exec("UPDATE books SET uuid = ? WHERE uuid = ?", newUUID, b.UUID) if err != nil { return errors.Wrapf(err, "updating book uuid from '%s' to '%s'", b.UUID, newUUID) @@ -140,8 +139,8 @@ func (b *Book) UpdateUUID(tx *sql.Tx, newUUID string) error { } // Expunge hard-deletes the book from the database -func (b Book) Expunge(tx *sql.Tx) error { - _, err := tx.Exec("DELETE FROM books WHERE uuid = ?", b.UUID) +func (b Book) Expunge(db *infra.DB) error { + _, err := db.Exec("DELETE FROM books WHERE uuid = ?", b.UUID) if err != nil { return errors.Wrap(err, "expunging a book locally") } diff --git a/core/operations.go b/core/operations.go new file mode 100644 index 00000000..92d617c7 --- /dev/null +++ b/core/operations.go @@ -0,0 +1,85 @@ +package core + +import ( + "encoding/base64" + + "github.com/dnote/cli/infra" + "github.com/pkg/errors" +) + +// InsertSystem inserets a system configuration +func InsertSystem(db *infra.DB, key, val string) error { + if _, err := db.Exec("INSERT INTO system (key, value) VALUES (? , ?);", key, val); err != nil { + return errors.Wrap(err, "saving system config") + } + + return nil +} + +// UpsertSystem inserts or updates a system configuration +func UpsertSystem(db *infra.DB, key, val string) error { + var count int + if err := db.QueryRow("SELECT count(*) FROM system WHERE key = ?", key).Scan(&count); err != nil { + return errors.Wrap(err, "counting system record") + } + + if count == 0 { + if _, err := db.Exec("INSERT INTO system (key, value) VALUES (? , ?);", key, val); err != nil { + return errors.Wrap(err, "saving system config") + } + } else { + if _, err := db.Exec("UPDATE system SET value = ? WHERE key = ?", val, key); err != nil { + return errors.Wrap(err, "updating system config") + } + } + + return nil +} + +// UpdateSystem updates a system configuration +func UpdateSystem(db *infra.DB, key, val interface{}) error { + if _, err := db.Exec("UPDATE system SET value = ? WHERE key = ?", val, key); err != nil { + return errors.Wrap(err, "updating system config") + } + + return nil +} + +// GetSystem scans the given system configuration record onto the destination +func GetSystem(db *infra.DB, key string, dest interface{}) error { + if err := db.QueryRow("SELECT value FROM system WHERE key = ?", key).Scan(dest); err != nil { + return errors.Wrap(err, "finding system configuration record") + } + + return nil +} + +// DeleteSystem delets the given system record +func DeleteSystem(db *infra.DB, key string) error { + if _, err := db.Exec("DELETE FROM system WHERE key = ?", key); err != nil { + return errors.Wrap(err, "deleting system config") + } + + return nil +} + +// GetCipherKey retrieves the cipher key and decode the base64 into bytes. +func GetCipherKey(ctx infra.DnoteCtx) ([]byte, error) { + db, err := ctx.DB.Begin() + if err != nil { + return nil, errors.Wrap(err, "beginning transaction") + } + + var cipherKeyB64 string + err = GetSystem(db, infra.SystemCipherKey, &cipherKeyB64) + if err != nil { + return []byte{}, errors.Wrap(err, "getting enc key") + } + + cipherKey, err := base64.StdEncoding.DecodeString(cipherKeyB64) + if err != nil { + return nil, errors.Wrap(err, "decoding cipherKey from base64") + } + + return cipherKey, nil +} diff --git a/core/operations_test.go b/core/operations_test.go new file mode 100644 index 00000000..51070a48 --- /dev/null +++ b/core/operations_test.go @@ -0,0 +1,221 @@ +package core + +import ( + "fmt" + "testing" + + "github.com/dnote/cli/testutils" + "github.com/pkg/errors" +) + +func TestInsertSystem(t *testing.T) { + testCases := []struct { + key string + val string + }{ + { + key: "foo", + val: "1558089284", + }, + { + key: "baz", + val: "quz", + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("insert %s %s", tc.key, tc.val), func(t *testing.T) { + // Setup + ctx := testutils.InitEnv(t, "../tmp", "../testutils/fixtures/schema.sql", true) + defer testutils.TeardownEnv(ctx) + + // execute + db := ctx.DB + + tx, err := db.Begin() + if err != nil { + t.Fatalf(errors.Wrap(err, "beginning a transaction").Error()) + } + + if err := InsertSystem(tx, tc.key, tc.val); err != nil { + tx.Rollback() + t.Fatalf(errors.Wrap(err, "executing for test case").Error()) + } + + tx.Commit() + + // test + var key, val string + testutils.MustScan(t, "getting the saved record", + db.QueryRow("SELECT key, value FROM system WHERE key = ?", tc.key), &key, &val) + + testutils.AssertEqual(t, key, tc.key, "key mismatch for test case") + testutils.AssertEqual(t, val, tc.val, "val mismatch for test case") + }) + } +} + +func TestUpsertSystem(t *testing.T) { + testCases := []struct { + key string + val string + countDelta int + }{ + { + key: "foo", + val: "1558089284", + countDelta: 1, + }, + { + key: "baz", + val: "quz2", + countDelta: 0, + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("insert %s %s", tc.key, tc.val), func(t *testing.T) { + // Setup + ctx := testutils.InitEnv(t, "../tmp", "../testutils/fixtures/schema.sql", true) + defer testutils.TeardownEnv(ctx) + + db := ctx.DB + testutils.MustExec(t, "inserting a system configuration", db, "INSERT INTO system (key, value) VALUES (?, ?)", "baz", "quz") + + var initialSystemCount int + testutils.MustScan(t, "counting records", db.QueryRow("SELECT count(*) FROM system"), &initialSystemCount) + + // execute + tx, err := db.Begin() + if err != nil { + t.Fatalf(errors.Wrap(err, "beginning a transaction").Error()) + } + + if err := UpsertSystem(tx, tc.key, tc.val); err != nil { + tx.Rollback() + t.Fatalf(errors.Wrap(err, "executing for test case").Error()) + } + + tx.Commit() + + // test + var key, val string + testutils.MustScan(t, "getting the saved record", + db.QueryRow("SELECT key, value FROM system WHERE key = ?", tc.key), &key, &val) + var systemCount int + testutils.MustScan(t, "counting records", + db.QueryRow("SELECT count(*) FROM system"), &systemCount) + + testutils.AssertEqual(t, key, tc.key, "key mismatch") + testutils.AssertEqual(t, val, tc.val, "val mismatch") + testutils.AssertEqual(t, systemCount, initialSystemCount+tc.countDelta, "count mismatch") + }) + } +} + +func TestGetSystem(t *testing.T) { + t.Run(fmt.Sprintf("get string value"), func(t *testing.T) { + // Setup + ctx := testutils.InitEnv(t, "../tmp", "../testutils/fixtures/schema.sql", true) + defer testutils.TeardownEnv(ctx) + + // execute + db := ctx.DB + testutils.MustExec(t, "inserting a system configuration", db, "INSERT INTO system (key, value) VALUES (?, ?)", "foo", "bar") + + tx, err := db.Begin() + if err != nil { + t.Fatalf(errors.Wrap(err, "beginning a transaction").Error()) + } + var dest string + if err := GetSystem(tx, "foo", &dest); err != nil { + tx.Rollback() + t.Fatalf(errors.Wrap(err, "executing for test case").Error()) + } + tx.Commit() + + // test + testutils.AssertEqual(t, dest, "bar", "dest mismatch") + }) + + t.Run(fmt.Sprintf("get int64 value"), func(t *testing.T) { + // Setup + ctx := testutils.InitEnv(t, "../tmp", "../testutils/fixtures/schema.sql", true) + defer testutils.TeardownEnv(ctx) + + // execute + db := ctx.DB + testutils.MustExec(t, "inserting a system configuration", db, "INSERT INTO system (key, value) VALUES (?, ?)", "foo", 1234) + + tx, err := db.Begin() + if err != nil { + t.Fatalf(errors.Wrap(err, "beginning a transaction").Error()) + } + var dest int64 + if err := GetSystem(tx, "foo", &dest); err != nil { + tx.Rollback() + t.Fatalf(errors.Wrap(err, "executing for test case").Error()) + } + tx.Commit() + + // test + testutils.AssertEqual(t, dest, int64(1234), "dest mismatch") + }) +} + +func TestUpdateSystem(t *testing.T) { + testCases := []struct { + key string + val string + countDelta int + }{ + { + key: "foo", + val: "1558089284", + }, + { + key: "foo", + val: "bar", + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("update %s %s", tc.key, tc.val), func(t *testing.T) { + // Setup + ctx := testutils.InitEnv(t, "../tmp", "../testutils/fixtures/schema.sql", true) + defer testutils.TeardownEnv(ctx) + + db := ctx.DB + testutils.MustExec(t, "inserting a system configuration", db, "INSERT INTO system (key, value) VALUES (?, ?)", "foo", "fuz") + testutils.MustExec(t, "inserting a system configuration", db, "INSERT INTO system (key, value) VALUES (?, ?)", "baz", "quz") + + var initialSystemCount int + testutils.MustScan(t, "counting records", db.QueryRow("SELECT count(*) FROM system"), &initialSystemCount) + + // execute + tx, err := db.Begin() + if err != nil { + t.Fatalf(errors.Wrap(err, "beginning a transaction").Error()) + } + + if err := UpdateSystem(tx, tc.key, tc.val); err != nil { + tx.Rollback() + t.Fatalf(errors.Wrap(err, "executing for test case").Error()) + } + + tx.Commit() + + // test + var key, val string + testutils.MustScan(t, "getting the saved record", + db.QueryRow("SELECT key, value FROM system WHERE key = ?", tc.key), &key, &val) + var systemCount int + testutils.MustScan(t, "counting records", + db.QueryRow("SELECT count(*) FROM system"), &systemCount) + + testutils.AssertEqual(t, key, tc.key, "key mismatch") + testutils.AssertEqual(t, val, tc.val, "val mismatch") + testutils.AssertEqual(t, systemCount, initialSystemCount, "count mismatch") + }) + } +} diff --git a/crypt/utils.go b/crypt/utils.go new file mode 100644 index 00000000..abee0c59 --- /dev/null +++ b/crypt/utils.go @@ -0,0 +1,107 @@ +// Package crypt provides cryptographic funcitonalities +package crypt + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "io" + + "github.com/dnote/cli/log" + "github.com/pkg/errors" + "golang.org/x/crypto/hkdf" + "golang.org/x/crypto/pbkdf2" +) + +var aesGcmNonceSize = 12 + +func runHkdf(secret, salt, info []byte) ([]byte, error) { + r := hkdf.New(sha256.New, secret, salt, info) + + ret := make([]byte, 32) + _, err := io.ReadFull(r, ret) + if err != nil { + return []byte{}, errors.Wrap(err, "reading key bytes") + } + + return ret, nil +} + +// MakeKeys derives, from the given credential, a key set comprising of an encryption key +// and an authentication key +func MakeKeys(password, email []byte, iteration int) ([]byte, []byte, error) { + masterKey := pbkdf2.Key([]byte(password), []byte(email), iteration, 32, sha256.New) + log.Debug("email: %s, password: %s", email, password) + + authKey, err := runHkdf(masterKey, email, []byte("auth")) + if err != nil { + return nil, nil, errors.Wrap(err, "deriving auth key") + } + + return masterKey, authKey, nil +} + +// AesGcmEncrypt encrypts the plaintext using AES in a GCM mode. It returns +// a ciphertext prepended by a 12 byte pseudo-random nonce, encoded in base64. +func AesGcmEncrypt(key, plaintext []byte) (string, error) { + if key == nil { + return "", errors.New("no key provided") + } + + block, err := aes.NewCipher(key) + if err != nil { + return "", errors.Wrap(err, "initializing aes") + } + + aesgcm, err := cipher.NewGCM(block) + if err != nil { + return "", errors.Wrap(err, "initializing gcm") + } + + nonce := make([]byte, aesGcmNonceSize) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", errors.Wrap(err, "generating nonce") + } + + ciphertext := aesgcm.Seal(nonce, nonce, []byte(plaintext), nil) + cipherKeyB64 := base64.StdEncoding.EncodeToString(ciphertext) + + return cipherKeyB64, nil +} + +// AesGcmDecrypt decrypts the encrypted data using AES in a GCM mode. The data should be +// a base64 encoded string in the format of 12 byte nonce followed by a ciphertext. +func AesGcmDecrypt(key []byte, dataB64 string) ([]byte, error) { + if key == nil { + return nil, errors.New("no key provided") + } + + data, err := base64.StdEncoding.DecodeString(dataB64) + if err != nil { + return nil, errors.Wrap(err, "decoding base64 data") + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, errors.Wrap(err, "initializing aes") + } + + aesgcm, err := cipher.NewGCM(block) + if err != nil { + return nil, errors.Wrap(err, "initializing gcm") + } + + if len(data) < aesGcmNonceSize { + return nil, errors.Wrap(err, "malformed data") + } + + nonce, ciphertext := data[:aesGcmNonceSize], data[aesGcmNonceSize:] + plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, errors.Wrap(err, "decrypting") + } + + return plaintext, nil +} diff --git a/crypt/utils_test.go b/crypt/utils_test.go new file mode 100644 index 00000000..d6448a2f --- /dev/null +++ b/crypt/utils_test.go @@ -0,0 +1,100 @@ +package crypt + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/base64" + "fmt" + "testing" + + "github.com/dnote/cli/testutils" + "github.com/pkg/errors" +) + +func TestAesGcmEncrypt(t *testing.T) { + testCases := []struct { + key []byte + plaintext []byte + }{ + { + key: []byte("AES256Key-32Characters1234567890"), + plaintext: []byte("foo bar baz quz"), + }, + { + key: []byte("AES256Key-32Charactersabcdefghij"), + plaintext: []byte("1234 foo 5678 bar 7890 baz"), + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("key %s plaintext %s", tc.key, tc.plaintext), func(t *testing.T) { + // encrypt + dataB64, err := AesGcmEncrypt(tc.key, tc.plaintext) + if err != nil { + t.Fatal(errors.Wrap(err, "performing encryption")) + } + + // test that data can be decrypted + data, err := base64.StdEncoding.DecodeString(dataB64) + if err != nil { + t.Fatal(errors.Wrap(err, "decoding data from base64")) + } + + nonce, ciphertext := data[:12], data[12:] + + fmt.Println(string(data)) + + block, err := aes.NewCipher([]byte(tc.key)) + if err != nil { + t.Fatal(errors.Wrap(err, "initializing aes")) + } + + aesgcm, err := cipher.NewGCM(block) + if err != nil { + t.Fatal(errors.Wrap(err, "initializing gcm")) + } + + plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + t.Fatal(errors.Wrap(err, "decode")) + } + + testutils.AssertDeepEqual(t, plaintext, tc.plaintext, "plaintext mismatch") + }) + } +} + +func TestAesGcmDecrypt(t *testing.T) { + testCases := []struct { + key []byte + ciphertextB64 string + expectedPlaintext string + }{ + { + key: []byte("AES256Key-32Characters1234567890"), + ciphertextB64: "M2ov9hWMQ52v1S/zigwX3bJt4cVCV02uiRm/grKqN/rZxNkJrD7vK4Ii0g==", + expectedPlaintext: "foo bar baz quz", + }, + { + key: []byte("AES256Key-32Characters1234567890"), + ciphertextB64: "M4csFKUIUbD1FBEzLgHjscoKgN0lhMGJ0n2nKWiCkE/qSKlRP7kS", + expectedPlaintext: "foo\n1\nbar\n2", + }, + { + key: []byte("AES256Key-32Characters1234567890"), + ciphertextB64: "pe/fnw73MR1clmVIlRSJ5gDwBdnPly/DF7DsR5dJVz4dHZlv0b10WzvJEGOCHZEr+Q==", + expectedPlaintext: "föo\nbār\nbåz & qūz", + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("key %s ciphertext %s", tc.key, tc.ciphertextB64), func(t *testing.T) { + plaintext, err := AesGcmDecrypt(tc.key, tc.ciphertextB64) + if err != nil { + t.Fatal(errors.Wrap(err, "performing decryption")) + } + + testutils.AssertDeepEqual(t, plaintext, []byte(tc.expectedPlaintext), "plaintext mismatch") + }) + } +} diff --git a/infra/main.go b/infra/main.go index 28bf87e7..914f112f 100644 --- a/infra/main.go +++ b/infra/main.go @@ -3,6 +3,7 @@ package infra import ( "database/sql" + "encoding/base64" "fmt" "os" "os/user" @@ -27,21 +28,29 @@ var ( SystemLastMaxUSN = "last_max_usn" // SystemLastUpgrade is the timestamp at which the system more recently checked for an upgrade SystemLastUpgrade = "last_upgrade" + // SystemCipherKey is the encryption key + SystemCipherKey = "enc_key" + // SystemSessionKey is the session key + SystemSessionKey = "session_token" + // SystemSessionKeyExpiry is the timestamp at which the session key will expire + SystemSessionKeyExpiry = "session_token_expiry" ) // DnoteCtx is a context holding the information of the current runtime type DnoteCtx struct { - HomeDir string - DnoteDir string - APIEndpoint string - Version string - DB *sql.DB + HomeDir string + DnoteDir string + APIEndpoint string + Version string + DB *DB + SessionKey string + SessionKeyExpiry int64 + CipherKey []byte } // Config holds dnote configuration type Config struct { Editor string - APIKey string } // NewCtx returns a new dnote context @@ -53,7 +62,7 @@ func NewCtx(apiEndpoint, versionTag string) (DnoteCtx, error) { dnoteDir := getDnoteDir(homeDir) dnoteDBPath := fmt.Sprintf("%s/dnote.db", dnoteDir) - db, err := sql.Open("sqlite3", dnoteDBPath) + db, err := OpenDB(dnoteDBPath) if err != nil { return DnoteCtx{}, errors.Wrap(err, "conntecting to db") } @@ -69,6 +78,45 @@ func NewCtx(apiEndpoint, versionTag string) (DnoteCtx, error) { return ret, nil } +// SetupCtx populates context and returns a new context +func SetupCtx(ctx DnoteCtx) (DnoteCtx, error) { + db := ctx.DB + + var sessionKey, cipherKeyB64 string + var sessionKeyExpiry int64 + + err := db.QueryRow("SELECT value FROM system WHERE key = ?", SystemSessionKey).Scan(&sessionKey) + if err != nil && err != sql.ErrNoRows { + return ctx, errors.Wrap(err, "finding sesison key") + } + err = db.QueryRow("SELECT value FROM system WHERE key = ?", SystemCipherKey).Scan(&cipherKeyB64) + if err != nil && err != sql.ErrNoRows { + return ctx, errors.Wrap(err, "finding sesison key") + } + err = db.QueryRow("SELECT value FROM system WHERE key = ?", SystemSessionKeyExpiry).Scan(&sessionKeyExpiry) + if err != nil && err != sql.ErrNoRows { + return ctx, errors.Wrap(err, "finding sesison key expiry") + } + + cipherKey, err := base64.StdEncoding.DecodeString(cipherKeyB64) + if err != nil { + return ctx, errors.Wrap(err, "decoding cipherKey from base64") + } + + ret := DnoteCtx{ + HomeDir: ctx.HomeDir, + DnoteDir: ctx.DnoteDir, + APIEndpoint: ctx.APIEndpoint, + Version: ctx.Version, + DB: ctx.DB, + SessionKey: sessionKey, + SessionKeyExpiry: sessionKeyExpiry, + CipherKey: cipherKey, + } + + return ret, nil +} + func getDnoteDir(homeDir string) string { var ret string diff --git a/infra/sql.go b/infra/sql.go new file mode 100644 index 00000000..7ed8326d --- /dev/null +++ b/infra/sql.go @@ -0,0 +1,120 @@ +package infra + +import ( + "database/sql" + + "github.com/pkg/errors" +) + +// SQLCommon is the minimal interface required by a db connection +type SQLCommon interface { + Exec(query string, args ...interface{}) (sql.Result, error) + Prepare(query string) (*sql.Stmt, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row +} + +// sqlDb is an interface implemented by *sql.DB +type sqlDb interface { + Begin() (*sql.Tx, error) +} + +// sqlTx is an interface implemented by *sql.Tx +type sqlTx interface { + Commit() error + Rollback() error +} + +// DB contains information about the current database connection +type DB struct { + Conn SQLCommon +} + +// OpenDB initializes a new connection to the sqlite database +func OpenDB(dbPath string) (*DB, error) { + dbConn, err := sql.Open("sqlite3", dbPath) + if err != nil { + return nil, errors.Wrap(err, "opening db connection") + } + + // Send a ping to ensure that the connection is established + // if err := dbConn.Ping(); err != nil { + // dbConn.Close() + // return nil, errors.Wrap(err, "ping") + // } + + db := &DB{ + Conn: dbConn, + } + + return db, nil +} + +// Begin begins a transaction +func (d *DB) Begin() (*DB, error) { + if db, ok := d.Conn.(sqlDb); ok && db != nil { + tx, err := db.Begin() + if err != nil { + return nil, err + } + + return &DB{Conn: tx}, nil + } + + return nil, errors.New("can't start transaction") +} + +// Commit commits a transaction +func (d *DB) Commit() error { + if db, ok := d.Conn.(sqlTx); ok && db != nil { + if err := db.Commit(); err != nil { + return err + } + } + + return errors.New("invalid transaction") +} + +// Rollback rolls back a transaction +func (d *DB) Rollback() error { + if db, ok := d.Conn.(sqlTx); ok && db != nil { + if err := db.Rollback(); err != nil { + return err + } + } + + return errors.New("invalid transaction") +} + +// Exec executes a sql +func (d *DB) Exec(query string, values ...interface{}) (sql.Result, error) { + return d.Conn.Exec(query, values...) +} + +// Prepare prepares a sql +func (d *DB) Prepare(query string) (*sql.Stmt, error) { + return d.Conn.Prepare(query) +} + +// Query queries rows +func (d *DB) Query(query string, values ...interface{}) (*sql.Rows, error) { + return d.Conn.Query(query, values...) +} + +// QueryRow queries a row +func (d *DB) QueryRow(query string, values ...interface{}) *sql.Row { + return d.Conn.QueryRow(query, values...) +} + +type closer interface { + Close() error +} + +// Close closes a db connection +func (d *DB) Close() error { + if db, ok := d.Conn.(closer); ok { + return db.Close() + } + + return errors.New("can't close db") +} diff --git a/log/log.go b/log/log.go index 5cf648dc..24a25d6d 100644 --- a/log/log.go +++ b/log/log.go @@ -16,14 +16,14 @@ var ( // ColorBlue is a blue foreground color ColorBlue = color.New(color.FgBlue) // ColorGray is a gray foreground color - ColorGray = color.New(color.FgWhite) + ColorGray = color.New(color.FgHiBlack) ) var indent = " " // Info prints information func Info(msg string) { - fmt.Fprintf(color.Output, "%s%s %s\n", indent, ColorBlue.Sprint("•"), msg) + fmt.Fprintf(color.Output, "%s%s %s", indent, ColorBlue.Sprint("•"), msg) } // Infof prints information with optional format verbs @@ -58,7 +58,7 @@ func Warnf(msg string, v ...interface{}) { // Error prints an error message func Error(msg string) { - fmt.Fprintf(color.Output, "%s%s %s\n", indent, ColorRed.Sprint("⨯"), msg) + fmt.Fprintf(color.Output, "%s%s %s", indent, ColorRed.Sprint("⨯"), msg) } // Errorf prints an error message with optional format verbs @@ -71,6 +71,21 @@ func Printf(msg string, v ...interface{}) { fmt.Fprintf(color.Output, "%s%s %s", indent, ColorGray.Sprint("•"), fmt.Sprintf(msg, v...)) } +// Askf prints an question with optional format verbs. The leading symbol differs in color depending +// on whether the input is masked. +func Askf(msg string, masked bool, v ...interface{}) { + symbolChar := "[?]" + + var symbol string + if masked { + symbol = ColorGray.Sprintf(symbolChar) + } else { + symbol = ColorGreen.Sprintf(symbolChar) + } + + fmt.Fprintf(color.Output, "%s%s %s: ", indent, symbol, fmt.Sprintf(msg, v...)) +} + // Debug prints to the console if DNOTE_DEBUG is set func Debug(msg string, v ...interface{}) { if os.Getenv("DNOTE_DEBUG") == "1" { diff --git a/main.go b/main.go index 50484f81..2f601d7c 100644 --- a/main.go +++ b/main.go @@ -15,6 +15,7 @@ import ( "github.com/dnote/cli/cmd/edit" "github.com/dnote/cli/cmd/find" "github.com/dnote/cli/cmd/login" + "github.com/dnote/cli/cmd/logout" "github.com/dnote/cli/cmd/ls" "github.com/dnote/cli/cmd/remove" "github.com/dnote/cli/cmd/sync" @@ -37,9 +38,15 @@ func main() { panic(errors.Wrap(err, "preparing dnote run")) } + ctx, err = infra.SetupCtx(ctx) + if err != nil { + panic(errors.Wrap(err, "setting up context")) + } + root.Register(remove.NewCmd(ctx)) root.Register(edit.NewCmd(ctx)) root.Register(login.NewCmd(ctx)) + root.Register(logout.NewCmd(ctx)) root.Register(add.NewCmd(ctx)) root.Register(ls.NewCmd(ctx)) root.Register(sync.NewCmd(ctx)) diff --git a/migrate/migrate_test.go b/migrate/migrate_test.go index 9c937163..f5eb7250 100644 --- a/migrate/migrate_test.go +++ b/migrate/migrate_test.go @@ -1,7 +1,6 @@ package migrate import ( - "database/sql" "encoding/json" "fmt" "net/http" @@ -39,13 +38,13 @@ func TestExecute_bump_schema(t *testing.T) { m1 := migration{ name: "noop", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, db *infra.DB) error { return nil }, } m2 := migration{ name: "noop", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, db *infra.DB) error { return nil }, } @@ -97,28 +96,28 @@ func TestRun_nonfresh(t *testing.T) { sequence := []migration{ migration{ name: "v1", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, db *infra.DB) error { testutils.MustExec(t, "marking v1 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v1") return nil }, }, migration{ name: "v2", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, db *infra.DB) error { testutils.MustExec(t, "marking v2 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v2") return nil }, }, migration{ name: "v3", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, db *infra.DB) error { testutils.MustExec(t, "marking v3 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v3") return nil }, }, migration{ name: "v4", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, db *infra.DB) error { testutils.MustExec(t, "marking v4 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v4") return nil }, @@ -176,21 +175,21 @@ func TestRun_fresh(t *testing.T) { sequence := []migration{ migration{ name: "v1", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, db *infra.DB) error { testutils.MustExec(t, "marking v1 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v1") return nil }, }, migration{ name: "v2", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, db *infra.DB) error { testutils.MustExec(t, "marking v2 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v2") return nil }, }, migration{ name: "v3", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, db *infra.DB) error { testutils.MustExec(t, "marking v3 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v3") return nil }, @@ -250,21 +249,21 @@ func TestRun_up_to_date(t *testing.T) { sequence := []migration{ migration{ name: "v1", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, db *infra.DB) error { testutils.MustExec(t, "marking v1 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v1") return nil }, }, migration{ name: "v2", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, db *infra.DB) error { testutils.MustExec(t, "marking v2 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v2") return nil }, }, migration{ name: "v3", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, db *infra.DB) error { testutils.MustExec(t, "marking v3 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v3") return nil }, @@ -876,6 +875,7 @@ func TestLocalMigration9(t *testing.T) { func TestRemoteMigration1(t *testing.T) { // set up ctx := testutils.InitEnv(t, "../tmp", "./fixtures/remote-1-pre-schema.sql", false) + testutils.Login(t, &ctx) defer testutils.TeardownEnv(ctx) JSBookUUID := "existing-js-book-uuid" @@ -914,14 +914,13 @@ func TestRemoteMigration1(t *testing.T) { ctx.APIEndpoint = server.URL - confStr := fmt.Sprintf("apikey: mock_api_key") - testutils.WriteFile(ctx, []byte(confStr), "dnoterc") - db := ctx.DB testutils.MustExec(t, "inserting js book", db, "INSERT INTO books (uuid, label) VALUES (?, ?)", JSBookUUID, "js") testutils.MustExec(t, "inserting css book", db, "INSERT INTO books (uuid, label) VALUES (?, ?)", CSSBookUUID, "css") testutils.MustExec(t, "inserting linux book", db, "INSERT INTO books (uuid, label) VALUES (?, ?)", linuxBookUUID, "linux") + testutils.MustExec(t, "inserting sessionKey", db, "INSERT INTO system (key, value) VALUES (?, ?)", infra.SystemSessionKey, "someSessionKey") + testutils.MustExec(t, "inserting sessionKeyExpiry", db, "INSERT INTO system (key, value) VALUES (?, ?)", infra.SystemSessionKeyExpiry, time.Now().Add(24*time.Hour).Unix()) tx, err := db.Begin() if err != nil { diff --git a/migrate/migrations.go b/migrate/migrations.go index 100350b2..dc964f8b 100644 --- a/migrate/migrations.go +++ b/migrate/migrations.go @@ -7,7 +7,6 @@ import ( "github.com/dnote/actions" "github.com/dnote/cli/client" - "github.com/dnote/cli/core" "github.com/dnote/cli/infra" "github.com/dnote/cli/log" "github.com/pkg/errors" @@ -15,12 +14,12 @@ import ( type migration struct { name string - run func(ctx infra.DnoteCtx, tx *sql.Tx) error + run func(ctx infra.DnoteCtx, tx *infra.DB) error } var lm1 = migration{ name: "upgrade-edit-note-from-v1-to-v3", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, tx *infra.DB) error { rows, err := tx.Query("SELECT uuid, data FROM actions WHERE type = ? AND schema = ?", "edit_note", 1) if err != nil { return errors.Wrap(err, "querying rows") @@ -68,7 +67,7 @@ var lm1 = migration{ var lm2 = migration{ name: "upgrade-edit-note-from-v2-to-v3", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, tx *infra.DB) error { rows, err := tx.Query("SELECT uuid, data FROM actions WHERE type = ? AND schema = ?", "edit_note", 2) if err != nil { return errors.Wrap(err, "querying rows") @@ -113,7 +112,7 @@ var lm2 = migration{ var lm3 = migration{ name: "upgrade-remove-note-from-v1-to-v2", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, tx *infra.DB) error { rows, err := tx.Query("SELECT uuid, data FROM actions WHERE type = ? AND schema = ?", "remove_note", 1) if err != nil { return errors.Wrap(err, "querying rows") @@ -155,7 +154,7 @@ var lm3 = migration{ var lm4 = migration{ name: "add-dirty-usn-deleted-to-notes-and-books", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, tx *infra.DB) error { _, err := tx.Exec("ALTER TABLE books ADD COLUMN dirty bool DEFAULT false") if err != nil { return errors.Wrap(err, "adding dirty column to books") @@ -192,7 +191,7 @@ var lm4 = migration{ var lm5 = migration{ name: "mark-action-targets-dirty", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, tx *infra.DB) error { rows, err := tx.Query("SELECT uuid, data, type FROM actions") if err != nil { return errors.Wrap(err, "querying rows") @@ -254,7 +253,7 @@ var lm5 = migration{ var lm6 = migration{ name: "drop-actions", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, tx *infra.DB) error { _, err := tx.Exec("DROP TABLE actions;") if err != nil { return errors.Wrap(err, "dropping the actions table") @@ -266,7 +265,7 @@ var lm6 = migration{ var lm7 = migration{ name: "resolve-conflicts-with-reserved-book-names", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, tx *infra.DB) error { migrateBook := func(name string) error { var uuid string @@ -313,7 +312,7 @@ var lm7 = migration{ var lm8 = migration{ name: "drop-note-id-and-rename-content-to-body", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, tx *infra.DB) error { _, err := tx.Exec(`CREATE TABLE notes_tmp ( uuid text NOT NULL, @@ -352,7 +351,7 @@ var lm8 = migration{ var lm9 = migration{ name: "create-fts-index", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { + run: func(ctx infra.DnoteCtx, tx *infra.DB) error { _, err := tx.Exec(`CREATE VIRTUAL TABLE IF NOT EXISTS note_fts USING fts5(content=notes, body, tokenize="porter unicode61 categories 'L* N* Co Ps Pe'");`) if err != nil { return errors.Wrap(err, "creating note_fts") @@ -388,16 +387,13 @@ var lm9 = migration{ var rm1 = migration{ name: "sync-book-uuids-from-server", - run: func(ctx infra.DnoteCtx, tx *sql.Tx) error { - config, err := core.ReadConfig(ctx) - if err != nil { - return errors.Wrap(err, "reading the config") - } - if config.APIKey == "" { - return errors.New("login required") + run: func(ctx infra.DnoteCtx, tx *infra.DB) error { + sessionKey := ctx.SessionKey + if sessionKey == "" { + return errors.New("not logged in") } - resp, err := client.GetBooks(ctx, config.APIKey) + resp, err := client.GetBooks(ctx, sessionKey) if err != nil { return errors.Wrap(err, "getting books from the server") } diff --git a/scripts/test.sh b/scripts/test.sh index 6e322de9..7c05bef5 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -1,12 +1,17 @@ #!/bin/bash - # run_server_test.sh runs server test files sequentially # https://stackoverflow.com/questions/23715302/go-how-to-run-tests-for-multiple-packages +set -e + # clear tmp dir in case not properly torn down -rm -rf ./tmp +rm -rf $GOPATH/src/github.com/dnote/cli/tmp # run test +pushd $GOPATH/src/github.com/dnote/cli + go test ./... \ -p 1\ --tags "fts5" + +popd diff --git a/testutils/main.go b/testutils/main.go index 7ae009f1..5179fe13 100644 --- a/testutils/main.go +++ b/testutils/main.go @@ -14,6 +14,7 @@ import ( "reflect" "strings" "testing" + "time" "github.com/dnote/cli/infra" "github.com/dnote/cli/utils" @@ -21,13 +22,8 @@ import ( ) // InitEnv sets up a test env and returns a new dnote context -func InitEnv(t *testing.T, relPath string, relFixturePath string, migrated bool) infra.DnoteCtx { - path, err := filepath.Abs(relPath) - if err != nil { - t.Fatal(errors.Wrap(err, "pasrsing path")) - } - - os.Setenv("DNOTE_HOME_DIR", path) +func InitEnv(t *testing.T, dnotehomePath string, fixturePath string, migrated bool) infra.DnoteCtx { + os.Setenv("DNOTE_HOME_DIR", dnotehomePath) ctx, err := infra.NewCtx("", "") if err != nil { t.Fatal(errors.Wrap(err, "getting new ctx")) @@ -39,7 +35,7 @@ func InitEnv(t *testing.T, relPath string, relFixturePath string, migrated bool) } // set up db - b := ReadFileAbs(relFixturePath) + b := ReadFileAbs(fixturePath) setupSQL := string(b) db := ctx.DB @@ -61,6 +57,19 @@ func InitEnv(t *testing.T, relPath string, relFixturePath string, migrated bool) return ctx } +// Login simulates a logged in user by inserting credentials in the local database +func Login(t *testing.T, ctx *infra.DnoteCtx) { + db := ctx.DB + + MustExec(t, "inserting sessionKey", db, "INSERT INTO system (key, value) VALUES (?, ?)", infra.SystemSessionKey, "someSessionKey") + MustExec(t, "inserting sessionKeyExpiry", db, "INSERT INTO system (key, value) VALUES (?, ?)", infra.SystemSessionKeyExpiry, time.Now().Add(24*time.Hour).Unix()) + MustExec(t, "inserting cipherKey", db, "INSERT INTO system (key, value) VALUES (?, ?)", infra.SystemCipherKey, "QUVTMjU2S2V5LTMyQ2hhcmFjdGVyczEyMzQ1Njc4OTA=") + + ctx.SessionKey = "someSessionKey" + ctx.SessionKeyExpiry = time.Now().Add(24 * time.Hour).Unix() + ctx.CipherKey = []byte("AES256Key-32Characters1234567890") +} + // TeardownEnv cleans up the test env represented by the given context func TeardownEnv(ctx infra.DnoteCtx) { ctx.DB.Close() @@ -210,7 +219,7 @@ func IsEqualJSON(s1, s2 []byte) (bool, error) { } // MustExec executes the given SQL query and fails a test if an error occurs -func MustExec(t *testing.T, message string, db *sql.DB, query string, args ...interface{}) sql.Result { +func MustExec(t *testing.T, message string, db *infra.DB, query string, args ...interface{}) sql.Result { result, err := db.Exec(query, args...) if err != nil { t.Fatal(errors.Wrap(errors.Wrap(err, "executing sql"), message)) @@ -309,7 +318,7 @@ func WaitDnoteCmd(t *testing.T, ctx infra.DnoteCtx, runFunc func(io.WriteCloser) func UserConfirm(stdin io.WriteCloser) error { // confirm if _, err := io.WriteString(stdin, "y\n"); err != nil { - return errors.Wrap(err, "confirming deletion") + return errors.Wrap(err, "indicating confirmation in stdin") } return nil diff --git a/utils/utils.go b/utils/utils.go index 14593b3f..f33da5fd 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -9,11 +9,13 @@ import ( "os" "path/filepath" "strings" + "syscall" "github.com/dnote/cli/infra" "github.com/dnote/cli/log" "github.com/pkg/errors" "github.com/satori/go.uuid" + "golang.org/x/crypto/ssh/terminal" ) // GenerateUUID returns a uid @@ -28,7 +30,38 @@ func getInput() (string, error) { return "", errors.Wrap(err, "reading stdin") } - return input, nil + return strings.Trim(input, "\r\n"), nil +} + +// PromptInput prompts the user input and saves the result to the destination +func PromptInput(message string, dest *string) error { + log.Askf(message, false) + + input, err := getInput() + if err != nil { + return errors.Wrap(err, "getting user input") + } + + *dest = input + + return nil +} + +// PromptPassword prompts the user input a password and saves the result to the destination. +// The input is masked, meaning it is not echoed on the terminal. +func PromptPassword(message string, dest *string) error { + log.Askf(message, true) + + password, err := terminal.ReadPassword(syscall.Stdin) + if err != nil { + return errors.Wrap(err, "getting user input") + } + + fmt.Println("") + + *dest = string(password) + + return nil } // AskConfirmation prompts for user input to confirm a choice @@ -40,17 +73,17 @@ func AskConfirmation(question string, optimistic bool) (bool, error) { choices = "(y/N)" } - log.Printf("%s %s: ", question, choices) + message := fmt.Sprintf("%s %s", question, choices) - res, err := getInput() - if err != nil { + var input string + if err := PromptInput(message, &input); err != nil { return false, errors.Wrap(err, "Failed to get user input") } - confirmed := res == "y\n" || res == "y\r\n" + confirmed := input == "y" if optimistic { - confirmed = confirmed || res == "\n" || res == "\r\n" + confirmed = confirmed || input == "" } return confirmed, nil @@ -109,18 +142,48 @@ func CopyDir(src, dest string) error { return nil } -// DoAuthorizedReq does a http request to the given path in the api endpoint as a user, -// with the appropriate headers. The given path should include the preceding slash. -func DoAuthorizedReq(ctx infra.DnoteCtx, apiKey, method, path, body string) (*http.Response, error) { +func getReq(ctx infra.DnoteCtx, path, method, body string) (*http.Request, error) { endpoint := fmt.Sprintf("%s%s", ctx.APIEndpoint, path) req, err := http.NewRequest(method, endpoint, strings.NewReader(body)) if err != nil { return nil, errors.Wrap(err, "constructing http request") } - req.Header.Set("Authorization", apiKey) req.Header.Set("CLI-Version", ctx.Version) + return req, nil +} + +// DoAuthorizedReq does a http request to the given path in the api endpoint as a user, +// with the appropriate headers. The given path should include the preceding slash. +func DoAuthorizedReq(ctx infra.DnoteCtx, hc http.Client, method, path, body string) (*http.Response, error) { + if ctx.SessionKey == "" { + return nil, errors.New("no session key found") + } + + req, err := getReq(ctx, path, method, body) + if err != nil { + return nil, errors.Wrap(err, "getting request") + } + + credential := fmt.Sprintf("Bearer %s", ctx.SessionKey) + req.Header.Set("Authorization", credential) + + res, err := hc.Do(req) + if err != nil { + return res, errors.Wrap(err, "making http request") + } + + return res, nil +} + +// DoReq does a http request to the given path in the api endpoint +func DoReq(ctx infra.DnoteCtx, method, path, body string) (*http.Response, error) { + req, err := getReq(ctx, path, method, body) + if err != nil { + return nil, errors.Wrap(err, "getting request") + } + hc := http.Client{} res, err := hc.Do(req) if err != nil {