Encryption (#165)

* Implement login and logout

* Add encrypt util

* Use v2

* Abstract common interface between db and tx

* Fix test

* Check login

* Fix test

* Fix login

* Fix path

* Improve test

* Fix output
This commit is contained in:
Sung Won Cho 2019-03-31 16:23:46 +11:00 committed by GitHub
commit 73526a943c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 1445 additions and 244 deletions

View file

@ -5,8 +5,9 @@
- [edit](#dnote-edit)
- [remove](#dnote-remove)
- [find](#dnote-find)
- [login](#dnote-login)
- [sync](#dnote-sync)
- [login](#dnote-login)
- [logout](#dnote-logout)
## dnote add
@ -87,7 +88,7 @@ dnote find "merge sort" -b algorithm
## dnote sync
_Dnote Cloud only_
_Dnote Pro only_
_alias: s_
@ -95,6 +96,12 @@ Sync notes with Dnote cloud
## dnote login
_Dnote Cloud only_
_Dnote Pro only_
Start a login prompt
## dnote logout
_Dnote Pro only_
Log out of Dnote

20
Gopkg.lock generated
View file

@ -97,11 +97,26 @@
revision = "298182f68c66c05229eb03ac171abe6e309ee79a"
version = "v1.0.3"
[[projects]]
branch = "master"
digest = "1:d0f4eb7abce3fbd3f0dcbbc03ffe18464846afd34c815928d2ae11c1e5aded04"
name = "golang.org/x/crypto"
packages = [
"hkdf",
"pbkdf2",
"ssh/terminal",
]
pruneopts = ""
revision = "ffb98f73852f696ea2bb21a617a5c4b3e067a439"
[[projects]]
branch = "master"
digest = "1:7e3b61f51ebcb58b3894928ed7c63aae68820dec1dd57166e5d6e65ef2868f40"
name = "golang.org/x/sys"
packages = ["unix"]
packages = [
"unix",
"windows",
]
pruneopts = ""
revision = "b90733256f2e882e81d52f9126de08df5615afd9"
@ -124,6 +139,9 @@
"github.com/pkg/errors",
"github.com/satori/go.uuid",
"github.com/spf13/cobra",
"golang.org/x/crypto/hkdf",
"golang.org/x/crypto/pbkdf2",
"golang.org/x/crypto/ssh/terminal",
"gopkg.in/yaml.v2",
]
solver-name = "gps-cdcl"

View file

@ -30,7 +30,7 @@
[[constraint]]
name = "github.com/spf13/cobra"
version = "0.0.1"
version = "0.0.3"
[[constraint]]
branch = "v2"

View file

@ -12,11 +12,15 @@ import (
"strings"
"time"
"github.com/dnote/cli/crypt"
"github.com/dnote/cli/infra"
"github.com/dnote/cli/utils"
"github.com/pkg/errors"
)
// ErrInvalidLogin is an error for invalid credentials for login
var ErrInvalidLogin = errors.New("wrong credentials")
// GetSyncStateResp is the response get sync state endpoint
type GetSyncStateResp struct {
FullSyncBefore int `json:"full_sync_before"`
@ -25,10 +29,11 @@ type GetSyncStateResp struct {
}
// GetSyncState gets the sync state response from the server
func GetSyncState(apiKey string, ctx infra.DnoteCtx) (GetSyncStateResp, error) {
func GetSyncState(ctx infra.DnoteCtx) (GetSyncStateResp, error) {
var ret GetSyncStateResp
res, err := utils.DoAuthorizedReq(ctx, apiKey, "GET", "/v1/sync/state", "")
hc := http.Client{}
res, err := utils.DoAuthorizedReq(ctx, hc, "GET", "/v1/sync/state", "")
if err != nil {
return ret, errors.Wrap(err, "constructing http request")
}
@ -89,13 +94,14 @@ type GetSyncFragmentResp struct {
}
// GetSyncFragment gets a sync fragment response from the server
func GetSyncFragment(ctx infra.DnoteCtx, apiKey string, afterUSN int) (GetSyncFragmentResp, error) {
func GetSyncFragment(ctx infra.DnoteCtx, afterUSN int) (GetSyncFragmentResp, error) {
v := url.Values{}
v.Set("after_usn", strconv.Itoa(afterUSN))
queryStr := v.Encode()
path := fmt.Sprintf("/v1/sync/fragment?%s", queryStr)
res, err := utils.DoAuthorizedReq(ctx, apiKey, "GET", path, "")
hc := http.Client{}
res, err := utils.DoAuthorizedReq(ctx, hc, "GET", path, "")
body, err := ioutil.ReadAll(res.Body)
if err != nil {
@ -148,25 +154,31 @@ func checkRespErr(res *http.Response) (bool, string, error) {
}
// CreateBook creates a new book in the server
func CreateBook(ctx infra.DnoteCtx, apiKey, label string) (CreateBookResp, error) {
func CreateBook(ctx infra.DnoteCtx, label string) (CreateBookResp, error) {
encLabel, err := crypt.AesGcmEncrypt(ctx.CipherKey, []byte(label))
if err != nil {
return CreateBookResp{}, errors.Wrap(err, "encrypting the label")
}
payload := CreateBookPayload{
Name: label,
Name: encLabel,
}
b, err := json.Marshal(payload)
if err != nil {
return CreateBookResp{}, errors.Wrap(err, "marshaling payload")
}
res, err := utils.DoAuthorizedReq(ctx, apiKey, "POST", "/v1/books", string(b))
hc := http.Client{}
res, err := utils.DoAuthorizedReq(ctx, hc, "POST", "/v2/books", string(b))
if err != nil {
return CreateBookResp{}, errors.Wrap(err, "posting a book to the server")
}
ok, message, err := checkRespErr(res)
hasErr, message, err := checkRespErr(res)
if err != nil {
return CreateBookResp{}, errors.Wrap(err, "checking repsonse error")
}
if ok {
if hasErr {
return CreateBookResp{}, errors.New(message)
}
@ -188,26 +200,32 @@ type UpdateBookResp struct {
}
// UpdateBook updates a book in the server
func UpdateBook(ctx infra.DnoteCtx, apiKey, label, uuid string) (UpdateBookResp, error) {
func UpdateBook(ctx infra.DnoteCtx, label, uuid string) (UpdateBookResp, error) {
encName, err := crypt.AesGcmEncrypt(ctx.CipherKey, []byte(label))
if err != nil {
return UpdateBookResp{}, errors.Wrap(err, "encrypting the content")
}
payload := updateBookPayload{
Name: &label,
Name: &encName,
}
b, err := json.Marshal(payload)
if err != nil {
return UpdateBookResp{}, errors.Wrap(err, "marshaling payload")
}
hc := http.Client{}
endpoint := fmt.Sprintf("/v1/books/%s", uuid)
res, err := utils.DoAuthorizedReq(ctx, apiKey, "PATCH", endpoint, string(b))
res, err := utils.DoAuthorizedReq(ctx, hc, "PATCH", endpoint, string(b))
if err != nil {
return UpdateBookResp{}, errors.Wrap(err, "posting a book to the server")
}
ok, message, err := checkRespErr(res)
hasErr, message, err := checkRespErr(res)
if err != nil {
return UpdateBookResp{}, errors.Wrap(err, "checking repsonse error")
}
if ok {
if hasErr {
return UpdateBookResp{}, errors.New(message)
}
@ -226,18 +244,19 @@ type DeleteBookResp struct {
}
// DeleteBook deletes a book in the server
func DeleteBook(ctx infra.DnoteCtx, apiKey, uuid string) (DeleteBookResp, error) {
func DeleteBook(ctx infra.DnoteCtx, uuid string) (DeleteBookResp, error) {
hc := http.Client{}
endpoint := fmt.Sprintf("/v1/books/%s", uuid)
res, err := utils.DoAuthorizedReq(ctx, apiKey, "DELETE", endpoint, "")
res, err := utils.DoAuthorizedReq(ctx, hc, "DELETE", endpoint, "")
if err != nil {
return DeleteBookResp{}, errors.Wrap(err, "deleting a book in the server")
}
ok, message, err := checkRespErr(res)
hasErr, message, err := checkRespErr(res)
if err != nil {
return DeleteBookResp{}, errors.Wrap(err, "checking repsonse error")
}
if ok {
if hasErr {
return DeleteBookResp{}, errors.New(message)
}
@ -283,26 +302,32 @@ type RespNote struct {
}
// CreateNote creates a note in the server
func CreateNote(ctx infra.DnoteCtx, apiKey, bookUUID, content string) (CreateNoteResp, error) {
func CreateNote(ctx infra.DnoteCtx, bookUUID, content string) (CreateNoteResp, error) {
encBody, err := crypt.AesGcmEncrypt(ctx.CipherKey, []byte(content))
if err != nil {
return CreateNoteResp{}, errors.Wrap(err, "encrypting the content")
}
payload := CreateNotePayload{
BookUUID: bookUUID,
Body: content,
Body: encBody,
}
b, err := json.Marshal(payload)
if err != nil {
return CreateNoteResp{}, errors.Wrap(err, "marshaling payload")
}
res, err := utils.DoAuthorizedReq(ctx, apiKey, "POST", "/v1/notes", string(b))
hc := http.Client{}
res, err := utils.DoAuthorizedReq(ctx, hc, "POST", "/v2/notes", string(b))
if err != nil {
return CreateNoteResp{}, errors.Wrap(err, "posting a book to the server")
}
ok, message, err := checkRespErr(res)
hasErr, message, err := checkRespErr(res)
if err != nil {
return CreateNoteResp{}, errors.Wrap(err, "checking repsonse error")
}
if ok {
if hasErr {
return CreateNoteResp{}, errors.New(message)
}
@ -327,10 +352,15 @@ type UpdateNoteResp struct {
}
// UpdateNote updates a note in the server
func UpdateNote(ctx infra.DnoteCtx, apiKey, uuid, bookUUID, content string, public bool) (UpdateNoteResp, error) {
func UpdateNote(ctx infra.DnoteCtx, uuid, bookUUID, content string, public bool) (UpdateNoteResp, error) {
encBody, err := crypt.AesGcmEncrypt(ctx.CipherKey, []byte(content))
if err != nil {
return UpdateNoteResp{}, errors.Wrap(err, "encrypting the content")
}
payload := updateNotePayload{
BookUUID: &bookUUID,
Body: &content,
Body: &encBody,
Public: &public,
}
b, err := json.Marshal(payload)
@ -338,17 +368,18 @@ func UpdateNote(ctx infra.DnoteCtx, apiKey, uuid, bookUUID, content string, publ
return UpdateNoteResp{}, errors.Wrap(err, "marshaling payload")
}
hc := http.Client{}
endpoint := fmt.Sprintf("/v1/notes/%s", uuid)
res, err := utils.DoAuthorizedReq(ctx, apiKey, "PATCH", endpoint, string(b))
res, err := utils.DoAuthorizedReq(ctx, hc, "PATCH", endpoint, string(b))
if err != nil {
return UpdateNoteResp{}, errors.Wrap(err, "patching a note to the server")
}
ok, message, err := checkRespErr(res)
hasErr, message, err := checkRespErr(res)
if err != nil {
return UpdateNoteResp{}, errors.Wrap(err, "checking repsonse error")
}
if ok {
if hasErr {
return UpdateNoteResp{}, errors.New(message)
}
@ -367,18 +398,19 @@ type DeleteNoteResp struct {
}
// DeleteNote removes a note in the server
func DeleteNote(ctx infra.DnoteCtx, apiKey, uuid string) (DeleteNoteResp, error) {
func DeleteNote(ctx infra.DnoteCtx, uuid string) (DeleteNoteResp, error) {
hc := http.Client{}
endpoint := fmt.Sprintf("/v1/notes/%s", uuid)
res, err := utils.DoAuthorizedReq(ctx, apiKey, "DELETE", endpoint, "")
res, err := utils.DoAuthorizedReq(ctx, hc, "DELETE", endpoint, "")
if err != nil {
return DeleteNoteResp{}, errors.Wrap(err, "patching a note to the server")
}
ok, message, err := checkRespErr(res)
hasErr, message, err := checkRespErr(res)
if err != nil {
return DeleteNoteResp{}, errors.Wrap(err, "checking repsonse error")
}
if ok {
if hasErr {
return DeleteNoteResp{}, errors.New(message)
}
@ -397,17 +429,18 @@ type GetBooksResp []struct {
}
// GetBooks gets books from the server
func GetBooks(ctx infra.DnoteCtx, apiKey string) (GetBooksResp, error) {
res, err := utils.DoAuthorizedReq(ctx, apiKey, "GET", "/v1/books", "")
func GetBooks(ctx infra.DnoteCtx, sessionKey string) (GetBooksResp, error) {
hc := http.Client{}
res, err := utils.DoAuthorizedReq(ctx, hc, "GET", "/v1/books", "")
if err != nil {
return GetBooksResp{}, errors.Wrap(err, "making http request")
}
ok, message, err := checkRespErr(res)
hasErr, message, err := checkRespErr(res)
if err != nil {
return GetBooksResp{}, errors.Wrap(err, "checking repsonse error")
}
if ok {
if hasErr {
return GetBooksResp{}, errors.New(message)
}
@ -418,3 +451,104 @@ func GetBooks(ctx infra.DnoteCtx, apiKey string) (GetBooksResp, error) {
return resp, nil
}
// PresigninResponse is a reponse from /v1/presignin endpoint
type PresigninResponse struct {
Iteration int `json:"iteration"`
}
// GetPresignin gets presignin credentials
func GetPresignin(ctx infra.DnoteCtx, email string) (PresigninResponse, error) {
res, err := utils.DoReq(ctx, "GET", fmt.Sprintf("/v1/presignin?email=%s", email), "")
if err != nil {
return PresigninResponse{}, errors.Wrap(err, "making http request")
}
hasErr, message, err := checkRespErr(res)
if err != nil {
return PresigninResponse{}, errors.Wrap(err, "checking repsonse error")
}
if hasErr {
return PresigninResponse{}, errors.New(message)
}
var resp PresigninResponse
if err := json.NewDecoder(res.Body).Decode(&resp); err != nil {
return PresigninResponse{}, errors.Wrap(err, "decoding payload")
}
return resp, nil
}
// SigninPayload is a payload for /v1/signin
type SigninPayload struct {
Email string `json:"email"`
AuthKey string `json:"auth_key"`
}
// SigninResponse is a response from /v1/signin endpoint
type SigninResponse struct {
Key string `json:"key"`
ExpiresAt int64 `json:"expires_at"`
CipherKeyEnc string `json:"cipher_key_enc"`
}
// Signin requests a session token
func Signin(ctx infra.DnoteCtx, email, authKey string) (SigninResponse, error) {
payload := SigninPayload{
Email: email,
AuthKey: authKey,
}
b, err := json.Marshal(payload)
if err != nil {
return SigninResponse{}, errors.Wrap(err, "marshaling payload")
}
res, err := utils.DoReq(ctx, "POST", "/v1/signin", string(b))
if err != nil {
return SigninResponse{}, errors.Wrap(err, "making http request")
}
if res.StatusCode == http.StatusUnauthorized {
return SigninResponse{}, ErrInvalidLogin
}
hasErr, message, err := checkRespErr(res)
if err != nil {
return SigninResponse{}, errors.Wrap(err, "checking repsonse error")
}
if hasErr {
return SigninResponse{}, errors.New(message)
}
var resp SigninResponse
if err := json.NewDecoder(res.Body).Decode(&resp); err != nil {
return SigninResponse{}, errors.Wrap(err, "decoding payload")
}
return resp, nil
}
// Signout deletes a user session on the server side
func Signout(ctx infra.DnoteCtx, sessionKey string) error {
hc := http.Client{
// No need to follow redirect
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
res, err := utils.DoAuthorizedReq(ctx, hc, "POST", "/v1/signout", "")
if err != nil {
return errors.Wrap(err, "making http request")
}
hasErr, message, err := checkRespErr(res)
if err != nil {
return errors.Wrap(err, "checking repsonse error")
}
if hasErr {
return errors.New(message)
}
return nil
}

View file

@ -1,11 +1,15 @@
package login
import (
"fmt"
"encoding/base64"
"strconv"
"github.com/dnote/cli/client"
"github.com/dnote/cli/core"
"github.com/dnote/cli/crypt"
"github.com/dnote/cli/infra"
"github.com/dnote/cli/log"
"github.com/dnote/cli/utils"
"github.com/pkg/errors"
"github.com/spf13/cobra"
)
@ -25,36 +29,78 @@ func NewCmd(ctx infra.DnoteCtx) *cobra.Command {
return cmd
}
// Do dervies credentials on the client side and requests a session token from the server
func Do(ctx infra.DnoteCtx, email, password string) error {
presigninResp, err := client.GetPresignin(ctx, email)
if err != nil {
return errors.Wrap(err, "getting presiginin")
}
masterKey, authKey, err := crypt.MakeKeys([]byte(password), []byte(email), presigninResp.Iteration)
if err != nil {
return errors.Wrap(err, "making keys")
}
authKeyB64 := base64.StdEncoding.EncodeToString(authKey)
signinResp, err := client.Signin(ctx, email, authKeyB64)
if err != nil {
return errors.Wrap(err, "requesting session")
}
cipherKeyDec, err := crypt.AesGcmDecrypt(masterKey, signinResp.CipherKeyEnc)
if err != nil {
return errors.Wrap(err, "decrypting cipher key")
}
cipherKeyDecB64 := base64.StdEncoding.EncodeToString(cipherKeyDec)
db := ctx.DB
tx, err := db.Begin()
if err != nil {
return errors.Wrap(err, "beginning a transaction")
}
if err := core.UpsertSystem(tx, infra.SystemCipherKey, cipherKeyDecB64); err != nil {
return errors.Wrap(err, "saving enc key")
}
if err := core.UpsertSystem(tx, infra.SystemSessionKey, signinResp.Key); err != nil {
return errors.Wrap(err, "saving session key")
}
if err := core.UpsertSystem(tx, infra.SystemSessionKeyExpiry, strconv.FormatInt(signinResp.ExpiresAt, 10)); err != nil {
return errors.Wrap(err, "saving session key")
}
tx.Commit()
return nil
}
func newRun(ctx infra.DnoteCtx) core.RunEFunc {
return func(cmd *cobra.Command, args []string) error {
log.Plain("\n")
log.Plain(" _( )_( )_\n")
log.Plain(" (_ _ _)\n")
log.Plain(" (_) (__)\n\n")
log.Plain("Welcome to Dnote Cloud :)\n\n")
log.Plain("A home for your engineering microlessons\n")
log.Plain("You can register at https://dnote.io/cloud\n\n")
log.Printf("API key: ")
var apiKey string
fmt.Scanln(&apiKey)
if apiKey == "" {
return errors.New("Empty API key")
var email, password string
if err := utils.PromptInput("email", &email); err != nil {
return errors.Wrap(err, "getting email input")
}
if email == "" {
return errors.New("Email is empty")
}
config, err := core.ReadConfig(ctx)
if err != nil {
return err
if err := utils.PromptPassword("password", &password); err != nil {
return errors.Wrap(err, "getting password input")
}
if password == "" {
return errors.New("Password is empty")
}
config.APIKey = apiKey
err = core.WriteConfig(ctx, config)
if err != nil {
return errors.Wrap(err, "Failed to write to config file")
err := Do(ctx, email, password)
if errors.Cause(err) == client.ErrInvalidLogin {
log.Error("wrong login\n")
return nil
} else if err != nil {
return errors.Wrap(err, "logging in")
}
log.Success("configured\n")
log.Success("logged in\n")
return nil
}

82
cmd/logout/logout.go Normal file
View file

@ -0,0 +1,82 @@
package logout
import (
"database/sql"
"github.com/dnote/cli/client"
"github.com/dnote/cli/core"
"github.com/dnote/cli/infra"
"github.com/dnote/cli/log"
"github.com/pkg/errors"
"github.com/spf13/cobra"
)
// ErrNotLoggedIn is an error for logging out when not logged in
var ErrNotLoggedIn = errors.New("not logged in")
var example = `
dnote logout`
// NewCmd returns a new logout command
func NewCmd(ctx infra.DnoteCtx) *cobra.Command {
cmd := &cobra.Command{
Use: "logout",
Short: "Logout from the server",
Example: example,
RunE: newRun(ctx),
}
return cmd
}
// Do performs logout
func Do(ctx infra.DnoteCtx) error {
db := ctx.DB
tx, err := db.Begin()
if err != nil {
return errors.Wrap(err, "beginning a transaction")
}
var key string
err = core.GetSystem(tx, infra.SystemSessionKey, &key)
if errors.Cause(err) == sql.ErrNoRows {
return ErrNotLoggedIn
} else if err != nil {
return errors.Wrap(err, "getting session key")
}
err = client.Signout(ctx, key)
if err != nil {
return errors.Wrap(err, "requesting logout")
}
if err := core.DeleteSystem(tx, infra.SystemCipherKey); err != nil {
return errors.Wrap(err, "deleting enc key")
}
if err := core.DeleteSystem(tx, infra.SystemSessionKey); err != nil {
return errors.Wrap(err, "deleting session key")
}
if err := core.DeleteSystem(tx, infra.SystemSessionKeyExpiry); err != nil {
return errors.Wrap(err, "deleting session key expiry")
}
tx.Commit()
return nil
}
func newRun(ctx infra.DnoteCtx) core.RunEFunc {
return func(cmd *cobra.Command, args []string) error {
err := Do(ctx)
if err == ErrNotLoggedIn {
log.Error("not logged in\n")
return nil
} else if err != nil {
return errors.Wrap(err, "logging out")
}
log.Success("logged out\n")
return nil
}
}

View file

@ -6,6 +6,7 @@ import (
"github.com/dnote/cli/client"
"github.com/dnote/cli/core"
"github.com/dnote/cli/crypt"
"github.com/dnote/cli/infra"
"github.com/dnote/cli/log"
"github.com/dnote/cli/migrate"
@ -39,22 +40,20 @@ func NewCmd(ctx infra.DnoteCtx) *cobra.Command {
return cmd
}
func getLastSyncAt(tx *sql.Tx) (int, error) {
func getLastSyncAt(tx *infra.DB) (int, error) {
var ret int
err := tx.QueryRow("SELECT value FROM system WHERE key = ?", infra.SystemLastSyncAt).Scan(&ret)
if err != nil {
if err := core.GetSystem(tx, infra.SystemLastSyncAt, &ret); err != nil {
return ret, errors.Wrap(err, "querying last sync time")
}
return ret, nil
}
func getLastMaxUSN(tx *sql.Tx) (int, error) {
func getLastMaxUSN(tx *infra.DB) (int, error) {
var ret int
err := tx.QueryRow("SELECT value FROM system WHERE key = ?", infra.SystemLastMaxUSN).Scan(&ret)
if err != nil {
if err := core.GetSystem(tx, infra.SystemLastMaxUSN, &ret); err != nil {
return ret, errors.Wrap(err, "querying last user max_usn")
}
@ -75,7 +74,9 @@ func (l syncList) getLength() int {
return len(l.Notes) + len(l.Books) + len(l.ExpungedNotes) + len(l.ExpungedBooks)
}
func newSyncList(fragments []client.SyncFragment) syncList {
// processFragments categorizes items in sync fragments into a sync list. It also decrypts any
// encrypted data in sync fragments.
func processFragments(fragments []client.SyncFragment, cipherKey []byte) (syncList, error) {
notes := map[string]client.SyncFragNote{}
books := map[string]client.SyncFragBook{}
expungedNotes := map[string]bool{}
@ -85,9 +86,23 @@ func newSyncList(fragments []client.SyncFragment) syncList {
for _, fragment := range fragments {
for _, note := range fragment.Notes {
log.Debug("decrypting note %s\n", note.UUID)
bodyDec, err := crypt.AesGcmDecrypt(cipherKey, note.Body)
if err != nil {
return syncList{}, errors.Wrapf(err, "decrypting body for note %s", note.UUID)
}
note.Body = string(bodyDec)
notes[note.UUID] = note
}
for _, book := range fragment.Books {
log.Debug("decrypting book %s\n", book.UUID)
labelDec, err := crypt.AesGcmDecrypt(cipherKey, book.Label)
if err != nil {
return syncList{}, errors.Wrapf(err, "decrypting label for book %s", book.UUID)
}
book.Label = string(labelDec)
books[book.UUID] = book
}
for _, uuid := range fragment.ExpungedBooks {
@ -105,7 +120,7 @@ func newSyncList(fragments []client.SyncFragment) syncList {
}
}
return syncList{
sl := syncList{
Notes: notes,
Books: books,
ExpungedNotes: expungedNotes,
@ -113,30 +128,35 @@ func newSyncList(fragments []client.SyncFragment) syncList {
MaxUSN: maxUSN,
MaxCurrentTime: maxCurrentTime,
}
return sl, nil
}
// getSyncList gets a list of all sync fragments after the specified usn
// and aggregates them into a syncList data structure
func getSyncList(ctx infra.DnoteCtx, apiKey string, afterUSN int) (syncList, error) {
fragments, err := getSyncFragments(ctx, apiKey, afterUSN)
func getSyncList(ctx infra.DnoteCtx, afterUSN int) (syncList, error) {
fragments, err := getSyncFragments(ctx, afterUSN)
if err != nil {
return syncList{}, errors.Wrap(err, "getting sync fragments")
}
ret := newSyncList(fragments)
ret, err := processFragments(fragments, ctx.CipherKey)
if err != nil {
return syncList{}, errors.Wrap(err, "making sync list")
}
return ret, nil
}
// getSyncFragments repeatedly gets all sync fragments after the specified usn until there is no more new data
// remaining and returns the buffered list
func getSyncFragments(ctx infra.DnoteCtx, apiKey string, afterUSN int) ([]client.SyncFragment, error) {
func getSyncFragments(ctx infra.DnoteCtx, afterUSN int) ([]client.SyncFragment, error) {
var buf []client.SyncFragment
nextAfterUSN := afterUSN
for {
resp, err := client.GetSyncFragment(ctx, apiKey, nextAfterUSN)
resp, err := client.GetSyncFragment(ctx, nextAfterUSN)
if err != nil {
return buf, errors.Wrap(err, "getting sync fragment")
}
@ -159,7 +179,7 @@ func getSyncFragments(ctx infra.DnoteCtx, apiKey string, afterUSN int) ([]client
// resolveLabel resolves a book label conflict by repeatedly appending an increasing integer
// to the label until it finds a unique label. It returns the first non-conflicting label.
func resolveLabel(tx *sql.Tx, label string) (string, error) {
func resolveLabel(tx *infra.DB, label string) (string, error) {
var ret string
for i := 2; ; i++ {
@ -180,7 +200,7 @@ func resolveLabel(tx *sql.Tx, label string) (string, error) {
// mergeBook inserts or updates the given book in the local database.
// If a book with a duplicate label exists locally, it renames the duplicate by appending a number.
func mergeBook(tx *sql.Tx, b client.SyncFragBook, mode int) error {
func mergeBook(tx *infra.DB, b client.SyncFragBook, mode int) error {
var count int
if err := tx.QueryRow("SELECT count(*) FROM books WHERE label = ?", b.Label).Scan(&count); err != nil {
return errors.Wrapf(err, "checking for books with a duplicate label %s", b.Label)
@ -214,7 +234,7 @@ func mergeBook(tx *sql.Tx, b client.SyncFragBook, mode int) error {
return nil
}
func stepSyncBook(tx *sql.Tx, b client.SyncFragBook) error {
func stepSyncBook(tx *infra.DB, b client.SyncFragBook) error {
var localUSN int
var dirty bool
err := tx.QueryRow("SELECT usn, dirty FROM books WHERE uuid = ?", b.UUID).Scan(&localUSN, &dirty)
@ -238,7 +258,7 @@ func stepSyncBook(tx *sql.Tx, b client.SyncFragBook) error {
return nil
}
func mergeNote(tx *sql.Tx, serverNote client.SyncFragNote, localNote core.Note) error {
func mergeNote(tx *infra.DB, serverNote client.SyncFragNote, localNote core.Note) error {
var bookDeleted bool
err := tx.QueryRow("SELECT deleted FROM books WHERE uuid = ?", localNote.BookUUID).Scan(&bookDeleted)
if err != nil {
@ -269,7 +289,7 @@ func mergeNote(tx *sql.Tx, serverNote client.SyncFragNote, localNote core.Note)
return nil
}
func stepSyncNote(tx *sql.Tx, n client.SyncFragNote) error {
func stepSyncNote(tx *infra.DB, n client.SyncFragNote) error {
var localNote core.Note
err := tx.QueryRow("SELECT usn, book_uuid, dirty, deleted FROM notes WHERE uuid = ?", n.UUID).
Scan(&localNote.USN, &localNote.BookUUID, &localNote.Dirty, &localNote.Deleted)
@ -293,7 +313,7 @@ func stepSyncNote(tx *sql.Tx, n client.SyncFragNote) error {
return nil
}
func fullSyncNote(tx *sql.Tx, n client.SyncFragNote) error {
func fullSyncNote(tx *infra.DB, n client.SyncFragNote) error {
var localNote core.Note
err := tx.QueryRow("SELECT usn,book_uuid, dirty, deleted FROM notes WHERE uuid = ?", n.UUID).
Scan(&localNote.USN, &localNote.BookUUID, &localNote.Dirty, &localNote.Deleted)
@ -317,7 +337,7 @@ func fullSyncNote(tx *sql.Tx, n client.SyncFragNote) error {
return nil
}
func syncDeleteNote(tx *sql.Tx, noteUUID string) error {
func syncDeleteNote(tx *infra.DB, noteUUID string) error {
var localUSN int
var dirty bool
err := tx.QueryRow("SELECT usn, dirty FROM notes WHERE uuid = ?", noteUUID).Scan(&localUSN, &dirty)
@ -342,7 +362,7 @@ func syncDeleteNote(tx *sql.Tx, noteUUID string) error {
}
// checkNotesPristine checks that none of the notes in the given book are dirty
func checkNotesPristine(tx *sql.Tx, bookUUID string) (bool, error) {
func checkNotesPristine(tx *infra.DB, bookUUID string) (bool, error) {
var count int
if err := tx.QueryRow("SELECT count(*) FROM notes WHERE book_uuid = ? AND dirty = ?", bookUUID, true).Scan(&count); err != nil {
return false, errors.Wrapf(err, "counting notes that are dirty in book %s", bookUUID)
@ -355,7 +375,7 @@ func checkNotesPristine(tx *sql.Tx, bookUUID string) (bool, error) {
return true, nil
}
func syncDeleteBook(tx *sql.Tx, bookUUID string) error {
func syncDeleteBook(tx *infra.DB, bookUUID string) error {
var localUSN int
var dirty bool
err := tx.QueryRow("SELECT usn, dirty FROM books WHERE uuid = ?", bookUUID).Scan(&localUSN, &dirty)
@ -401,7 +421,7 @@ func syncDeleteBook(tx *sql.Tx, bookUUID string) error {
return nil
}
func fullSyncBook(tx *sql.Tx, b client.SyncFragBook) error {
func fullSyncBook(tx *infra.DB, b client.SyncFragBook) error {
var localUSN int
var dirty bool
err := tx.QueryRow("SELECT usn, dirty FROM books WHERE uuid = ?", b.UUID).Scan(&localUSN, &dirty)
@ -453,7 +473,7 @@ func checkBookInList(uuid string, list *syncList) bool {
// judging by the full list of resources in the server. Concretely, the only acceptable
// situation in which a local note is not present in the server is if it is new and has not been
// uploaded (i.e. dirty and usn is 0). Otherwise, it is a result of some kind of error and should be cleaned.
func cleanLocalNotes(tx *sql.Tx, fullList *syncList) error {
func cleanLocalNotes(tx *infra.DB, fullList *syncList) error {
rows, err := tx.Query("SELECT uuid, usn, dirty FROM notes")
if err != nil {
return errors.Wrap(err, "getting local notes")
@ -479,7 +499,7 @@ func cleanLocalNotes(tx *sql.Tx, fullList *syncList) error {
}
// cleanLocalBooks deletes from the local database any books that are in invalid state
func cleanLocalBooks(tx *sql.Tx, fullList *syncList) error {
func cleanLocalBooks(tx *infra.DB, fullList *syncList) error {
rows, err := tx.Query("SELECT uuid, usn, dirty FROM books")
if err != nil {
return errors.Wrap(err, "getting local books")
@ -504,11 +524,11 @@ func cleanLocalBooks(tx *sql.Tx, fullList *syncList) error {
return nil
}
func fullSync(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) error {
func fullSync(ctx infra.DnoteCtx, tx *infra.DB) error {
log.Debug("performing a full sync\n")
log.Info("resolving delta.")
list, err := getSyncList(ctx, apiKey, 0)
list, err := getSyncList(ctx, 0)
if err != nil {
return errors.Wrap(err, "getting sync list")
}
@ -555,12 +575,12 @@ func fullSync(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) error {
return nil
}
func stepSync(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string, afterUSN int) error {
func stepSync(ctx infra.DnoteCtx, tx *infra.DB, afterUSN int) error {
log.Debug("performing a step sync\n")
log.Info("resolving delta.")
list, err := getSyncList(ctx, apiKey, afterUSN)
list, err := getSyncList(ctx, afterUSN)
if err != nil {
return errors.Wrap(err, "getting sync list")
}
@ -599,7 +619,7 @@ func stepSync(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string, afterUSN int) error
return nil
}
func sendBooks(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) {
func sendBooks(ctx infra.DnoteCtx, tx *infra.DB) (bool, error) {
isBehind := false
rows, err := tx.Query("SELECT uuid, label, usn, deleted FROM books WHERE dirty")
@ -629,7 +649,7 @@ func sendBooks(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) {
continue
} else {
resp, err := client.CreateBook(ctx, apiKey, book.Label)
resp, err := client.CreateBook(ctx, book.Label)
if err != nil {
return isBehind, errors.Wrap(err, "creating a book")
}
@ -655,7 +675,7 @@ func sendBooks(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) {
}
} else {
if book.Deleted {
resp, err := client.DeleteBook(ctx, apiKey, book.UUID)
resp, err := client.DeleteBook(ctx, book.UUID)
if err != nil {
return isBehind, errors.Wrap(err, "deleting a book")
}
@ -667,7 +687,7 @@ func sendBooks(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) {
respUSN = resp.Book.USN
} else {
resp, err := client.UpdateBook(ctx, apiKey, book.Label, book.UUID)
resp, err := client.UpdateBook(ctx, book.Label, book.UUID)
if err != nil {
return isBehind, errors.Wrap(err, "updating a book")
}
@ -703,7 +723,7 @@ func sendBooks(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) {
return isBehind, nil
}
func sendNotes(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) {
func sendNotes(ctx infra.DnoteCtx, tx *infra.DB) (bool, error) {
isBehind := false
rows, err := tx.Query("SELECT uuid, book_uuid, body, public, deleted, usn FROM notes WHERE dirty")
@ -734,7 +754,7 @@ func sendNotes(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) {
continue
} else {
resp, err := client.CreateNote(ctx, apiKey, note.BookUUID, note.Body)
resp, err := client.CreateNote(ctx, note.BookUUID, note.Body)
if err != nil {
return isBehind, errors.Wrap(err, "creating a note")
}
@ -755,7 +775,7 @@ func sendNotes(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) {
}
} else {
if note.Deleted {
resp, err := client.DeleteNote(ctx, apiKey, note.UUID)
resp, err := client.DeleteNote(ctx, note.UUID)
if err != nil {
return isBehind, errors.Wrap(err, "deleting a note")
}
@ -767,7 +787,7 @@ func sendNotes(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) {
respUSN = resp.Result.USN
} else {
resp, err := client.UpdateNote(ctx, apiKey, note.UUID, note.BookUUID, note.Body, note.Public)
resp, err := client.UpdateNote(ctx, note.UUID, note.BookUUID, note.Body, note.Public)
if err != nil {
return isBehind, errors.Wrap(err, "updating a note")
}
@ -803,7 +823,7 @@ func sendNotes(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) {
return isBehind, nil
}
func sendChanges(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) {
func sendChanges(ctx infra.DnoteCtx, tx *infra.DB) (bool, error) {
log.Info("sending changes.")
var delta int
@ -811,12 +831,12 @@ func sendChanges(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) {
fmt.Printf(" (total %d).", delta)
behind1, err := sendBooks(ctx, tx, apiKey)
behind1, err := sendBooks(ctx, tx)
if err != nil {
return behind1, errors.Wrap(err, "sending books")
}
behind2, err := sendNotes(ctx, tx, apiKey)
behind2, err := sendNotes(ctx, tx)
if err != nil {
return behind2, errors.Wrap(err, "sending notes")
}
@ -828,25 +848,23 @@ func sendChanges(ctx infra.DnoteCtx, tx *sql.Tx, apiKey string) (bool, error) {
return isBehind, nil
}
func updateLastMaxUSN(tx *sql.Tx, val int) error {
_, err := tx.Exec("UPDATE system SET value = ? WHERE key = ?", val, infra.SystemLastMaxUSN)
if err != nil {
func updateLastMaxUSN(tx *infra.DB, val int) error {
if err := core.UpdateSystem(tx, infra.SystemLastMaxUSN, val); err != nil {
return errors.Wrapf(err, "updating %s", infra.SystemLastMaxUSN)
}
return nil
}
func updateLastSyncAt(tx *sql.Tx, val int64) error {
_, err := tx.Exec("UPDATE system SET value = ? WHERE key = ?", val, infra.SystemLastSyncAt)
if err != nil {
func updateLastSyncAt(tx *infra.DB, val int64) error {
if err := core.UpdateSystem(tx, infra.SystemLastSyncAt, val); err != nil {
return errors.Wrapf(err, "updating %s", infra.SystemLastSyncAt)
}
return nil
}
func saveSyncState(tx *sql.Tx, serverTime int64, serverMaxUSN int) error {
func saveSyncState(tx *infra.DB, serverTime int64, serverMaxUSN int) error {
if err := updateLastMaxUSN(tx, serverMaxUSN); err != nil {
return errors.Wrap(err, "updating last max usn")
}
@ -859,26 +877,20 @@ func saveSyncState(tx *sql.Tx, serverTime int64, serverMaxUSN int) error {
func newRun(ctx infra.DnoteCtx) core.RunEFunc {
return func(cmd *cobra.Command, args []string) error {
config, err := core.ReadConfig(ctx)
if err != nil {
return errors.Wrap(err, "reading the config")
}
if config.APIKey == "" {
log.Error("login required. please run `dnote login`\n")
return nil
if ctx.SessionKey == "" || ctx.CipherKey == nil {
return errors.New("not logged in")
}
if err := migrate.Run(ctx, migrate.RemoteSequence, migrate.RemoteMode); err != nil {
return errors.Wrap(err, "running remote migrations")
}
db := ctx.DB
tx, err := db.Begin()
tx, err := ctx.DB.Begin()
if err != nil {
return errors.Wrap(err, "beginning a transaction")
}
syncState, err := client.GetSyncState(config.APIKey, ctx)
syncState, err := client.GetSyncState(ctx)
if err != nil {
return errors.Wrap(err, "getting the sync state from the server")
}
@ -895,9 +907,9 @@ func newRun(ctx infra.DnoteCtx) core.RunEFunc {
var syncErr error
if isFullSync || lastSyncAt < syncState.FullSyncBefore {
syncErr = fullSync(ctx, tx, config.APIKey)
syncErr = fullSync(ctx, tx)
} else if lastMaxUSN != syncState.MaxUSN {
syncErr = stepSync(ctx, tx, config.APIKey, lastMaxUSN)
syncErr = stepSync(ctx, tx, lastMaxUSN)
} else {
// if no need to sync from the server, simply update the last sync timestamp and proceed to send changes
err = updateLastSyncAt(tx, syncState.CurrentTime)
@ -907,10 +919,10 @@ func newRun(ctx infra.DnoteCtx) core.RunEFunc {
}
if syncErr != nil {
tx.Rollback()
return errors.Wrap(err, "syncing changes from the server")
return errors.Wrap(syncErr, "syncing changes from the server")
}
isBehind, err := sendChanges(ctx, tx, config.APIKey)
isBehind, err := sendChanges(ctx, tx)
if err != nil {
tx.Rollback()
return errors.Wrap(err, "sending changes")
@ -926,7 +938,7 @@ func newRun(ctx infra.DnoteCtx) core.RunEFunc {
return errors.Wrap(err, "getting the new last max_usn")
}
err = stepSync(ctx, tx, config.APIKey, updatedLastMaxUSN)
err = stepSync(ctx, tx, updatedLastMaxUSN)
if err != nil {
tx.Rollback()
return errors.Wrap(err, "performing the follow-up step sync")

View file

@ -11,12 +11,87 @@ import (
"github.com/dnote/cli/client"
"github.com/dnote/cli/core"
"github.com/dnote/cli/crypt"
"github.com/dnote/cli/infra"
"github.com/dnote/cli/testutils"
"github.com/dnote/cli/utils"
"github.com/pkg/errors"
)
var cipherKey = []byte("AES256Key-32Characters1234567890")
func TestProcessFragments(t *testing.T) {
// set up
ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true)
defer testutils.TeardownEnv(ctx)
fragments := []client.SyncFragment{
client.SyncFragment{
FragMaxUSN: 10,
UserMaxUSN: 10,
CurrentTime: 1550436136,
Notes: []client.SyncFragNote{
client.SyncFragNote{
UUID: "45546de0-40ed-45cf-9bfc-62ce729a7d3d",
Body: "7GgIppDdxDn+4DUoVoLXbncZDRqXGwbDVNF/eCssu+1BXMdq+HAziJHGgK7drdcIBtYDDXj0OwHz9dQDDOyWeNqkLWEIQ2Roygs229dRxdO3Z6ST+qSOr/9TTjDlFxydF5Ps7nAXdN9KVxH8FKIZDsxJ45qeLKpQK/6poAM39BCOiysqAXJQz9ngOJiqImAuftS6d/XhwX77QvnM91VCKK0tFmsMdDDw0J9QMwnlYU1CViHy1Hdhhcf9Ea38Mj4SCrWMPscXyP2fpAu5ukbIK3vS2pvbnH5vC8ZuvihrQif1BsiwfYmN981mLYs069Dn4B72qcXPwU7qrN3V0k57JGcAlTiEoOD5QowyraensQlR1doorLb43SjTiJLItougn5K5QPRiHuNxfv39pa7A0gKA1n/3UhG/SBuCpDuPYjwmBkvkzCKJNgpbLQ8p29JXMQcWrm4e9GfnVjMhAEtxttIta3MN6EcYG7cB1dJ04OLYVcJuRA==",
},
client.SyncFragNote{
UUID: "a25a5336-afe9-46c4-b881-acab911c0bc3",
Body: "WGzcYA6kLuUFEU7HLTDJt7UWF7fEmbCPHfC16VBrAyfT2wDejXbIuFpU5L7g0aU=",
},
},
Books: []client.SyncFragBook{
client.SyncFragBook{
UUID: "e8ac6f25-d95b-435a-9fae-094f7506a5ac",
Label: "qBrSrAcnTUHu51bIrv6jSA/dNffr/kRlIg+MklxeQQ==",
},
client.SyncFragBook{
UUID: "05fd8b95-ddcd-4071-9380-4358ffb8a436",
Label: "uHWoBFdKT78gTkFR7qhyzZkrn59c8ktEa8idrLkksKzIQ3VVAXxq0QZp7Uc=",
},
},
ExpungedNotes: []string{},
ExpungedBooks: []string{},
},
}
// exec
sl, err := processFragments(fragments, cipherKey)
if err != nil {
t.Fatalf(errors.Wrap(err, "executing").Error())
}
expected := syncList{
Notes: map[string]client.SyncFragNote{
"45546de0-40ed-45cf-9bfc-62ce729a7d3d": client.SyncFragNote{
UUID: "45546de0-40ed-45cf-9bfc-62ce729a7d3d",
Body: "Lorem ipsum dolor sit amet, consectetur adipiscing elit.\n Donec ac libero efficitur, posuere dui non, egestas lectus.\n Aliquam urna ligula, sagittis eu volutpat vel, consequat et augue.\n\n Ut mi urna, dignissim a ex eget, venenatis accumsan sem. Praesent facilisis, ligula hendrerit auctor varius, mauris metus hendrerit dolor, sit amet pulvinar.",
},
"a25a5336-afe9-46c4-b881-acab911c0bc3": client.SyncFragNote{
UUID: "a25a5336-afe9-46c4-b881-acab911c0bc3",
Body: "foo bar baz quz\nqux",
},
},
Books: map[string]client.SyncFragBook{
"e8ac6f25-d95b-435a-9fae-094f7506a5ac": client.SyncFragBook{
UUID: "e8ac6f25-d95b-435a-9fae-094f7506a5ac",
Label: "foo",
},
"05fd8b95-ddcd-4071-9380-4358ffb8a436": client.SyncFragBook{
UUID: "05fd8b95-ddcd-4071-9380-4358ffb8a436",
Label: "foo-bar-baz-1000",
},
},
ExpungedNotes: map[string]bool{},
ExpungedBooks: map[string]bool{},
MaxUSN: 10,
MaxCurrentTime: 1550436136,
}
// test
testutils.AssertDeepEqual(t, sl, expected, "syncList mismatch")
}
func TestGetLastSyncAt(t *testing.T) {
// set up
ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true)
@ -1769,6 +1844,7 @@ func TestMergeBook(t *testing.T) {
func TestSaveServerState(t *testing.T) {
// set up
ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true)
testutils.Login(t, &ctx)
defer testutils.TeardownEnv(ctx)
db := ctx.DB
@ -1812,6 +1888,7 @@ func TestSaveServerState(t *testing.T) {
func TestSendBooks(t *testing.T) {
// set up
ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true)
testutils.Login(t, &ctx)
defer testutils.TeardownEnv(ctx)
db := ctx.DB
@ -1846,9 +1923,9 @@ func TestSendBooks(t *testing.T) {
var updatesUUIDs []string
var deletedUUIDs []string
// fire up a test server
// fire up a test server. It decrypts the payload for test purposes.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == "/v1/books" && r.Method == "POST" {
if r.URL.String() == "/v2/books" && r.Method == "POST" {
var payload client.CreateBookPayload
err := json.NewDecoder(r.Body).Decode(&payload)
@ -1857,15 +1934,22 @@ func TestSendBooks(t *testing.T) {
return
}
createdLabels = append(createdLabels, payload.Name)
labelDec, err := crypt.AesGcmDecrypt(cipherKey, payload.Name)
if err != nil {
t.Fatalf(errors.Wrap(err, "decrypting label").Error())
}
labelDecStr := string(labelDec)
createdLabels = append(createdLabels, labelDecStr)
resp := client.CreateBookResp{
Book: client.RespBook{
UUID: fmt.Sprintf("server-%s-uuid", payload.Name),
UUID: fmt.Sprintf("server-%s-uuid", labelDecStr),
},
}
w.Header().Set("Body-Type", "application/json")
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -1879,14 +1963,14 @@ func TestSendBooks(t *testing.T) {
uuid := p[3]
updatesUUIDs = append(updatesUUIDs, uuid)
w.Header().Set("Body-Type", "application/json")
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("{}"))
return
} else if r.Method == "DELETE" {
uuid := p[3]
deletedUUIDs = append(deletedUUIDs, uuid)
w.Header().Set("Body-Type", "application/json")
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("{}"))
return
}
@ -1904,13 +1988,16 @@ func TestSendBooks(t *testing.T) {
t.Fatalf(errors.Wrap(err, "beginning a transaction").Error())
}
if _, err := sendBooks(ctx, tx, "mockAPIKey"); err != nil {
if _, err := sendBooks(ctx, tx); err != nil {
tx.Rollback()
t.Fatalf(errors.Wrap(err, "executing").Error())
}
tx.Commit()
// test
// First, decrypt data so that they can be asserted
sort.SliceStable(createdLabels, func(i, j int) bool {
return strings.Compare(createdLabels[i], createdLabels[j]) < 0
})
@ -1964,7 +2051,7 @@ func TestSendBooks(t *testing.T) {
func TestSendBooks_isBehind(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == "/v1/books" && r.Method == "POST" {
if r.URL.String() == "/v2/books" && r.Method == "POST" {
var payload client.CreateBookPayload
err := json.NewDecoder(r.Body).Decode(&payload)
@ -1979,7 +2066,7 @@ func TestSendBooks_isBehind(t *testing.T) {
},
}
w.Header().Set("Body-Type", "application/json")
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -1996,7 +2083,7 @@ func TestSendBooks_isBehind(t *testing.T) {
},
}
w.Header().Set("Body-Type", "application/json")
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -2009,7 +2096,7 @@ func TestSendBooks_isBehind(t *testing.T) {
},
}
w.Header().Set("Body-Type", "application/json")
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -2041,6 +2128,7 @@ func TestSendBooks_isBehind(t *testing.T) {
func() {
// set up
ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true)
testutils.Login(t, &ctx)
ctx.APIEndpoint = ts.URL
defer testutils.TeardownEnv(ctx)
@ -2055,7 +2143,7 @@ func TestSendBooks_isBehind(t *testing.T) {
t.Fatalf(errors.Wrap(err, fmt.Sprintf("beginning a transaction for test case %d", idx)).Error())
}
isBehind, err := sendBooks(ctx, tx, "mockAPIKey")
isBehind, err := sendBooks(ctx, tx)
if err != nil {
tx.Rollback()
t.Fatalf(errors.Wrap(err, fmt.Sprintf("executing for test case %d", idx)).Error())
@ -2088,6 +2176,7 @@ func TestSendBooks_isBehind(t *testing.T) {
func() {
// set up
ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true)
testutils.Login(t, &ctx)
ctx.APIEndpoint = ts.URL
defer testutils.TeardownEnv(ctx)
@ -2102,7 +2191,7 @@ func TestSendBooks_isBehind(t *testing.T) {
t.Fatalf(errors.Wrap(err, fmt.Sprintf("beginning a transaction for test case %d", idx)).Error())
}
isBehind, err := sendBooks(ctx, tx, "mockAPIKey")
isBehind, err := sendBooks(ctx, tx)
if err != nil {
tx.Rollback()
t.Fatalf(errors.Wrap(err, fmt.Sprintf("executing for test case %d", idx)).Error())
@ -2135,6 +2224,7 @@ func TestSendBooks_isBehind(t *testing.T) {
func() {
// set up
ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true)
testutils.Login(t, &ctx)
ctx.APIEndpoint = ts.URL
defer testutils.TeardownEnv(ctx)
@ -2149,7 +2239,7 @@ func TestSendBooks_isBehind(t *testing.T) {
t.Fatalf(errors.Wrap(err, fmt.Sprintf("beginning a transaction for test case %d", idx)).Error())
}
isBehind, err := sendBooks(ctx, tx, "mockAPIKey")
isBehind, err := sendBooks(ctx, tx)
if err != nil {
tx.Rollback()
t.Fatalf(errors.Wrap(err, fmt.Sprintf("executing for test case %d", idx)).Error())
@ -2169,6 +2259,7 @@ func TestSendBooks_isBehind(t *testing.T) {
func TestSendNotes(t *testing.T) {
// set up
ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true)
testutils.Login(t, &ctx)
defer testutils.TeardownEnv(ctx)
db := ctx.DB
@ -2203,9 +2294,9 @@ func TestSendNotes(t *testing.T) {
var updatedUUIDs []string
var deletedUUIDs []string
// fire up a test server
// fire up a test server. It decrypts the payload for test purposes.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == "/v1/notes" && r.Method == "POST" {
if r.URL.String() == "/v2/notes" && r.Method == "POST" {
var payload client.CreateNotePayload
err := json.NewDecoder(r.Body).Decode(&payload)
@ -2214,15 +2305,21 @@ func TestSendNotes(t *testing.T) {
return
}
createdBodys = append(createdBodys, payload.Body)
bodyDec, err := crypt.AesGcmDecrypt(cipherKey, payload.Body)
if err != nil {
t.Fatalf(errors.Wrap(err, "decrypting body").Error())
}
bodyDecStr := string(bodyDec)
createdBodys = append(createdBodys, bodyDecStr)
resp := client.CreateNoteResp{
Result: client.RespNote{
UUID: fmt.Sprintf("server-%s-uuid", payload.Body),
UUID: fmt.Sprintf("server-%s-uuid", bodyDecStr),
},
}
w.Header().Set("Body-Type", "application/json")
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -2236,14 +2333,14 @@ func TestSendNotes(t *testing.T) {
uuid := p[3]
updatedUUIDs = append(updatedUUIDs, uuid)
w.Header().Set("Body-Type", "application/json")
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("{}"))
return
} else if r.Method == "DELETE" {
uuid := p[3]
deletedUUIDs = append(deletedUUIDs, uuid)
w.Header().Set("Body-Type", "application/json")
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("{}"))
return
}
@ -2261,7 +2358,7 @@ func TestSendNotes(t *testing.T) {
t.Fatalf(errors.Wrap(err, "beginning a transaction").Error())
}
if _, err := sendNotes(ctx, tx, "mockAPIKey"); err != nil {
if _, err := sendNotes(ctx, tx); err != nil {
tx.Rollback()
t.Fatalf(errors.Wrap(err, "executing").Error())
}
@ -2327,7 +2424,7 @@ func TestSendNotes_isBehind(t *testing.T) {
},
}
w.Header().Set("Body-Type", "application/json")
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -2344,7 +2441,7 @@ func TestSendNotes_isBehind(t *testing.T) {
},
}
w.Header().Set("Body-Type", "application/json")
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -2357,7 +2454,7 @@ func TestSendNotes_isBehind(t *testing.T) {
},
}
w.Header().Set("Body-Type", "application/json")
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -2389,6 +2486,7 @@ func TestSendNotes_isBehind(t *testing.T) {
func() {
// set up
ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true)
testutils.Login(t, &ctx)
ctx.APIEndpoint = ts.URL
defer testutils.TeardownEnv(ctx)
@ -2404,7 +2502,7 @@ func TestSendNotes_isBehind(t *testing.T) {
t.Fatalf(errors.Wrap(err, fmt.Sprintf("beginning a transaction for test case %d", idx)).Error())
}
isBehind, err := sendNotes(ctx, tx, "mockAPIKey")
isBehind, err := sendNotes(ctx, tx)
if err != nil {
tx.Rollback()
t.Fatalf(errors.Wrap(err, fmt.Sprintf("executing for test case %d", idx)).Error())
@ -2437,6 +2535,7 @@ func TestSendNotes_isBehind(t *testing.T) {
func() {
// set up
ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true)
testutils.Login(t, &ctx)
ctx.APIEndpoint = ts.URL
defer testutils.TeardownEnv(ctx)
@ -2452,7 +2551,7 @@ func TestSendNotes_isBehind(t *testing.T) {
t.Fatalf(errors.Wrap(err, fmt.Sprintf("beginning a transaction for test case %d", idx)).Error())
}
isBehind, err := sendNotes(ctx, tx, "mockAPIKey")
isBehind, err := sendNotes(ctx, tx)
if err != nil {
tx.Rollback()
t.Fatalf(errors.Wrap(err, fmt.Sprintf("executing for test case %d", idx)).Error())
@ -2485,6 +2584,7 @@ func TestSendNotes_isBehind(t *testing.T) {
func() {
// set up
ctx := testutils.InitEnv(t, "../../tmp", "../../testutils/fixtures/schema.sql", true)
testutils.Login(t, &ctx)
ctx.APIEndpoint = ts.URL
defer testutils.TeardownEnv(ctx)
@ -2500,7 +2600,7 @@ func TestSendNotes_isBehind(t *testing.T) {
t.Fatalf(errors.Wrap(err, fmt.Sprintf("beginning a transaction for test case %d", idx)).Error())
}
isBehind, err := sendNotes(ctx, tx, "mockAPIKey")
isBehind, err := sendNotes(ctx, tx)
if err != nil {
tx.Rollback()
t.Fatalf(errors.Wrap(err, fmt.Sprintf("executing for test case %d", idx)).Error())

View file

@ -11,6 +11,7 @@ import (
"time"
"github.com/dnote/cli/infra"
"github.com/dnote/cli/log"
"github.com/dnote/cli/utils"
"github.com/pkg/errors"
"github.com/satori/go.uuid"
@ -267,9 +268,9 @@ func GetEditorInput(ctx infra.DnoteCtx, fpath string, content *string) error {
return nil
}
func initSystemKV(tx *sql.Tx, key string, val string) error {
func initSystemKV(db *infra.DB, key string, val string) error {
var count int
if err := tx.QueryRow("SELECT count(*) FROM system WHERE key = ?", key).Scan(&count); err != nil {
if err := db.QueryRow("SELECT count(*) FROM system WHERE key = ?", key).Scan(&count); err != nil {
return errors.Wrapf(err, "counting %s", key)
}
@ -277,8 +278,8 @@ func initSystemKV(tx *sql.Tx, key string, val string) error {
return nil
}
if _, err := tx.Exec("INSERT INTO system (key, value) VALUES (?, ?)", key, val); err != nil {
tx.Rollback()
if _, err := db.Exec("INSERT INTO system (key, value) VALUES (?, ?)", key, val); err != nil {
db.Rollback()
return errors.Wrapf(err, "inserting %s %s", key, val)
}
@ -310,3 +311,30 @@ func InitSystem(ctx infra.DnoteCtx) error {
return nil
}
// GetValidSession returns a session key from the local storage if one exists and is not expired
// If one does not exist or is expired, it prints out an instruction and returns false
func GetValidSession(ctx infra.DnoteCtx) (string, bool, error) {
db := ctx.DB
var sessionKey string
var sessionKeyExpires int64
if err := GetSystem(db, infra.SystemSessionKey, &sessionKey); err != nil {
return "", false, errors.Wrap(err, "getting session key")
}
if err := GetSystem(db, infra.SystemSessionKeyExpiry, &sessionKeyExpires); err != nil {
return "", false, errors.Wrap(err, "getting session key expiry")
}
if sessionKey == "" {
log.Error("login required. please run `dnote login`\n")
return "", false, nil
}
if sessionKeyExpires < time.Now().Unix() {
log.Error("sesison expired. please run `dnote login`\n")
return "", false, nil
}
return sessionKey, true, nil
}

View file

@ -1,8 +1,7 @@
package core
import (
"database/sql"
"github.com/dnote/cli/infra"
"github.com/pkg/errors"
)
@ -45,8 +44,8 @@ func NewNote(uuid, bookUUID, body string, addedOn, editedOn int64, usn int, publ
}
// Insert inserts a new note
func (n Note) Insert(tx *sql.Tx) error {
_, err := tx.Exec("INSERT INTO notes (uuid, book_uuid, body, added_on, edited_on, usn, public, deleted, dirty) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
func (n Note) Insert(db *infra.DB) error {
_, err := db.Exec("INSERT INTO notes (uuid, book_uuid, body, added_on, edited_on, usn, public, deleted, dirty) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
n.UUID, n.BookUUID, n.Body, n.AddedOn, n.EditedOn, n.USN, n.Public, n.Deleted, n.Dirty)
if err != nil {
@ -57,8 +56,8 @@ func (n Note) Insert(tx *sql.Tx) error {
}
// Update updates the note with the given data
func (n Note) Update(tx *sql.Tx) error {
_, err := tx.Exec("UPDATE notes SET book_uuid = ?, body = ?, added_on = ?, edited_on = ?, usn = ?, public = ?, deleted = ?, dirty = ? WHERE uuid = ?",
func (n Note) Update(db *infra.DB) error {
_, err := db.Exec("UPDATE notes SET book_uuid = ?, body = ?, added_on = ?, edited_on = ?, usn = ?, public = ?, deleted = ?, dirty = ? WHERE uuid = ?",
n.BookUUID, n.Body, n.AddedOn, n.EditedOn, n.USN, n.Public, n.Deleted, n.Dirty, n.UUID)
if err != nil {
@ -69,8 +68,8 @@ func (n Note) Update(tx *sql.Tx) error {
}
// UpdateUUID updates the uuid of a book
func (n *Note) UpdateUUID(tx *sql.Tx, newUUID string) error {
_, err := tx.Exec("UPDATE notes SET uuid = ? WHERE uuid = ?", newUUID, n.UUID)
func (n *Note) UpdateUUID(db *infra.DB, newUUID string) error {
_, err := db.Exec("UPDATE notes SET uuid = ? WHERE uuid = ?", newUUID, n.UUID)
if err != nil {
return errors.Wrapf(err, "updating note uuid from '%s' to '%s'", n.UUID, newUUID)
@ -82,8 +81,8 @@ func (n *Note) UpdateUUID(tx *sql.Tx, newUUID string) error {
}
// Expunge hard-deletes the note from the database
func (n Note) Expunge(tx *sql.Tx) error {
_, err := tx.Exec("DELETE FROM notes WHERE uuid = ?", n.UUID)
func (n Note) Expunge(db *infra.DB) error {
_, err := db.Exec("DELETE FROM notes WHERE uuid = ?", n.UUID)
if err != nil {
return errors.Wrap(err, "expunging a note locally")
}
@ -103,8 +102,8 @@ func NewBook(uuid, label string, usn int, deleted, dirty bool) Book {
}
// Insert inserts a new book
func (b Book) Insert(tx *sql.Tx) error {
_, err := tx.Exec("INSERT INTO books (uuid, label, usn, dirty, deleted) VALUES (?, ?, ?, ?, ?)",
func (b Book) Insert(db *infra.DB) error {
_, err := db.Exec("INSERT INTO books (uuid, label, usn, dirty, deleted) VALUES (?, ?, ?, ?, ?)",
b.UUID, b.Label, b.USN, b.Dirty, b.Deleted)
if err != nil {
@ -115,8 +114,8 @@ func (b Book) Insert(tx *sql.Tx) error {
}
// Update updates the book with the given data
func (b Book) Update(tx *sql.Tx) error {
_, err := tx.Exec("UPDATE books SET label = ?, usn = ?, dirty = ?, deleted = ? WHERE uuid = ?",
func (b Book) Update(db *infra.DB) error {
_, err := db.Exec("UPDATE books SET label = ?, usn = ?, dirty = ?, deleted = ? WHERE uuid = ?",
b.Label, b.USN, b.Dirty, b.Deleted, b.UUID)
if err != nil {
@ -127,8 +126,8 @@ func (b Book) Update(tx *sql.Tx) error {
}
// UpdateUUID updates the uuid of a book
func (b *Book) UpdateUUID(tx *sql.Tx, newUUID string) error {
_, err := tx.Exec("UPDATE books SET uuid = ? WHERE uuid = ?", newUUID, b.UUID)
func (b *Book) UpdateUUID(db *infra.DB, newUUID string) error {
_, err := db.Exec("UPDATE books SET uuid = ? WHERE uuid = ?", newUUID, b.UUID)
if err != nil {
return errors.Wrapf(err, "updating book uuid from '%s' to '%s'", b.UUID, newUUID)
@ -140,8 +139,8 @@ func (b *Book) UpdateUUID(tx *sql.Tx, newUUID string) error {
}
// Expunge hard-deletes the book from the database
func (b Book) Expunge(tx *sql.Tx) error {
_, err := tx.Exec("DELETE FROM books WHERE uuid = ?", b.UUID)
func (b Book) Expunge(db *infra.DB) error {
_, err := db.Exec("DELETE FROM books WHERE uuid = ?", b.UUID)
if err != nil {
return errors.Wrap(err, "expunging a book locally")
}

85
core/operations.go Normal file
View file

@ -0,0 +1,85 @@
package core
import (
"encoding/base64"
"github.com/dnote/cli/infra"
"github.com/pkg/errors"
)
// InsertSystem inserets a system configuration
func InsertSystem(db *infra.DB, key, val string) error {
if _, err := db.Exec("INSERT INTO system (key, value) VALUES (? , ?);", key, val); err != nil {
return errors.Wrap(err, "saving system config")
}
return nil
}
// UpsertSystem inserts or updates a system configuration
func UpsertSystem(db *infra.DB, key, val string) error {
var count int
if err := db.QueryRow("SELECT count(*) FROM system WHERE key = ?", key).Scan(&count); err != nil {
return errors.Wrap(err, "counting system record")
}
if count == 0 {
if _, err := db.Exec("INSERT INTO system (key, value) VALUES (? , ?);", key, val); err != nil {
return errors.Wrap(err, "saving system config")
}
} else {
if _, err := db.Exec("UPDATE system SET value = ? WHERE key = ?", val, key); err != nil {
return errors.Wrap(err, "updating system config")
}
}
return nil
}
// UpdateSystem updates a system configuration
func UpdateSystem(db *infra.DB, key, val interface{}) error {
if _, err := db.Exec("UPDATE system SET value = ? WHERE key = ?", val, key); err != nil {
return errors.Wrap(err, "updating system config")
}
return nil
}
// GetSystem scans the given system configuration record onto the destination
func GetSystem(db *infra.DB, key string, dest interface{}) error {
if err := db.QueryRow("SELECT value FROM system WHERE key = ?", key).Scan(dest); err != nil {
return errors.Wrap(err, "finding system configuration record")
}
return nil
}
// DeleteSystem delets the given system record
func DeleteSystem(db *infra.DB, key string) error {
if _, err := db.Exec("DELETE FROM system WHERE key = ?", key); err != nil {
return errors.Wrap(err, "deleting system config")
}
return nil
}
// GetCipherKey retrieves the cipher key and decode the base64 into bytes.
func GetCipherKey(ctx infra.DnoteCtx) ([]byte, error) {
db, err := ctx.DB.Begin()
if err != nil {
return nil, errors.Wrap(err, "beginning transaction")
}
var cipherKeyB64 string
err = GetSystem(db, infra.SystemCipherKey, &cipherKeyB64)
if err != nil {
return []byte{}, errors.Wrap(err, "getting enc key")
}
cipherKey, err := base64.StdEncoding.DecodeString(cipherKeyB64)
if err != nil {
return nil, errors.Wrap(err, "decoding cipherKey from base64")
}
return cipherKey, nil
}

221
core/operations_test.go Normal file
View file

@ -0,0 +1,221 @@
package core
import (
"fmt"
"testing"
"github.com/dnote/cli/testutils"
"github.com/pkg/errors"
)
func TestInsertSystem(t *testing.T) {
testCases := []struct {
key string
val string
}{
{
key: "foo",
val: "1558089284",
},
{
key: "baz",
val: "quz",
},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("insert %s %s", tc.key, tc.val), func(t *testing.T) {
// Setup
ctx := testutils.InitEnv(t, "../tmp", "../testutils/fixtures/schema.sql", true)
defer testutils.TeardownEnv(ctx)
// execute
db := ctx.DB
tx, err := db.Begin()
if err != nil {
t.Fatalf(errors.Wrap(err, "beginning a transaction").Error())
}
if err := InsertSystem(tx, tc.key, tc.val); err != nil {
tx.Rollback()
t.Fatalf(errors.Wrap(err, "executing for test case").Error())
}
tx.Commit()
// test
var key, val string
testutils.MustScan(t, "getting the saved record",
db.QueryRow("SELECT key, value FROM system WHERE key = ?", tc.key), &key, &val)
testutils.AssertEqual(t, key, tc.key, "key mismatch for test case")
testutils.AssertEqual(t, val, tc.val, "val mismatch for test case")
})
}
}
func TestUpsertSystem(t *testing.T) {
testCases := []struct {
key string
val string
countDelta int
}{
{
key: "foo",
val: "1558089284",
countDelta: 1,
},
{
key: "baz",
val: "quz2",
countDelta: 0,
},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("insert %s %s", tc.key, tc.val), func(t *testing.T) {
// Setup
ctx := testutils.InitEnv(t, "../tmp", "../testutils/fixtures/schema.sql", true)
defer testutils.TeardownEnv(ctx)
db := ctx.DB
testutils.MustExec(t, "inserting a system configuration", db, "INSERT INTO system (key, value) VALUES (?, ?)", "baz", "quz")
var initialSystemCount int
testutils.MustScan(t, "counting records", db.QueryRow("SELECT count(*) FROM system"), &initialSystemCount)
// execute
tx, err := db.Begin()
if err != nil {
t.Fatalf(errors.Wrap(err, "beginning a transaction").Error())
}
if err := UpsertSystem(tx, tc.key, tc.val); err != nil {
tx.Rollback()
t.Fatalf(errors.Wrap(err, "executing for test case").Error())
}
tx.Commit()
// test
var key, val string
testutils.MustScan(t, "getting the saved record",
db.QueryRow("SELECT key, value FROM system WHERE key = ?", tc.key), &key, &val)
var systemCount int
testutils.MustScan(t, "counting records",
db.QueryRow("SELECT count(*) FROM system"), &systemCount)
testutils.AssertEqual(t, key, tc.key, "key mismatch")
testutils.AssertEqual(t, val, tc.val, "val mismatch")
testutils.AssertEqual(t, systemCount, initialSystemCount+tc.countDelta, "count mismatch")
})
}
}
func TestGetSystem(t *testing.T) {
t.Run(fmt.Sprintf("get string value"), func(t *testing.T) {
// Setup
ctx := testutils.InitEnv(t, "../tmp", "../testutils/fixtures/schema.sql", true)
defer testutils.TeardownEnv(ctx)
// execute
db := ctx.DB
testutils.MustExec(t, "inserting a system configuration", db, "INSERT INTO system (key, value) VALUES (?, ?)", "foo", "bar")
tx, err := db.Begin()
if err != nil {
t.Fatalf(errors.Wrap(err, "beginning a transaction").Error())
}
var dest string
if err := GetSystem(tx, "foo", &dest); err != nil {
tx.Rollback()
t.Fatalf(errors.Wrap(err, "executing for test case").Error())
}
tx.Commit()
// test
testutils.AssertEqual(t, dest, "bar", "dest mismatch")
})
t.Run(fmt.Sprintf("get int64 value"), func(t *testing.T) {
// Setup
ctx := testutils.InitEnv(t, "../tmp", "../testutils/fixtures/schema.sql", true)
defer testutils.TeardownEnv(ctx)
// execute
db := ctx.DB
testutils.MustExec(t, "inserting a system configuration", db, "INSERT INTO system (key, value) VALUES (?, ?)", "foo", 1234)
tx, err := db.Begin()
if err != nil {
t.Fatalf(errors.Wrap(err, "beginning a transaction").Error())
}
var dest int64
if err := GetSystem(tx, "foo", &dest); err != nil {
tx.Rollback()
t.Fatalf(errors.Wrap(err, "executing for test case").Error())
}
tx.Commit()
// test
testutils.AssertEqual(t, dest, int64(1234), "dest mismatch")
})
}
func TestUpdateSystem(t *testing.T) {
testCases := []struct {
key string
val string
countDelta int
}{
{
key: "foo",
val: "1558089284",
},
{
key: "foo",
val: "bar",
},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("update %s %s", tc.key, tc.val), func(t *testing.T) {
// Setup
ctx := testutils.InitEnv(t, "../tmp", "../testutils/fixtures/schema.sql", true)
defer testutils.TeardownEnv(ctx)
db := ctx.DB
testutils.MustExec(t, "inserting a system configuration", db, "INSERT INTO system (key, value) VALUES (?, ?)", "foo", "fuz")
testutils.MustExec(t, "inserting a system configuration", db, "INSERT INTO system (key, value) VALUES (?, ?)", "baz", "quz")
var initialSystemCount int
testutils.MustScan(t, "counting records", db.QueryRow("SELECT count(*) FROM system"), &initialSystemCount)
// execute
tx, err := db.Begin()
if err != nil {
t.Fatalf(errors.Wrap(err, "beginning a transaction").Error())
}
if err := UpdateSystem(tx, tc.key, tc.val); err != nil {
tx.Rollback()
t.Fatalf(errors.Wrap(err, "executing for test case").Error())
}
tx.Commit()
// test
var key, val string
testutils.MustScan(t, "getting the saved record",
db.QueryRow("SELECT key, value FROM system WHERE key = ?", tc.key), &key, &val)
var systemCount int
testutils.MustScan(t, "counting records",
db.QueryRow("SELECT count(*) FROM system"), &systemCount)
testutils.AssertEqual(t, key, tc.key, "key mismatch")
testutils.AssertEqual(t, val, tc.val, "val mismatch")
testutils.AssertEqual(t, systemCount, initialSystemCount, "count mismatch")
})
}
}

107
crypt/utils.go Normal file
View file

@ -0,0 +1,107 @@
// Package crypt provides cryptographic funcitonalities
package crypt
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"io"
"github.com/dnote/cli/log"
"github.com/pkg/errors"
"golang.org/x/crypto/hkdf"
"golang.org/x/crypto/pbkdf2"
)
var aesGcmNonceSize = 12
func runHkdf(secret, salt, info []byte) ([]byte, error) {
r := hkdf.New(sha256.New, secret, salt, info)
ret := make([]byte, 32)
_, err := io.ReadFull(r, ret)
if err != nil {
return []byte{}, errors.Wrap(err, "reading key bytes")
}
return ret, nil
}
// MakeKeys derives, from the given credential, a key set comprising of an encryption key
// and an authentication key
func MakeKeys(password, email []byte, iteration int) ([]byte, []byte, error) {
masterKey := pbkdf2.Key([]byte(password), []byte(email), iteration, 32, sha256.New)
log.Debug("email: %s, password: %s", email, password)
authKey, err := runHkdf(masterKey, email, []byte("auth"))
if err != nil {
return nil, nil, errors.Wrap(err, "deriving auth key")
}
return masterKey, authKey, nil
}
// AesGcmEncrypt encrypts the plaintext using AES in a GCM mode. It returns
// a ciphertext prepended by a 12 byte pseudo-random nonce, encoded in base64.
func AesGcmEncrypt(key, plaintext []byte) (string, error) {
if key == nil {
return "", errors.New("no key provided")
}
block, err := aes.NewCipher(key)
if err != nil {
return "", errors.Wrap(err, "initializing aes")
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return "", errors.Wrap(err, "initializing gcm")
}
nonce := make([]byte, aesGcmNonceSize)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", errors.Wrap(err, "generating nonce")
}
ciphertext := aesgcm.Seal(nonce, nonce, []byte(plaintext), nil)
cipherKeyB64 := base64.StdEncoding.EncodeToString(ciphertext)
return cipherKeyB64, nil
}
// AesGcmDecrypt decrypts the encrypted data using AES in a GCM mode. The data should be
// a base64 encoded string in the format of 12 byte nonce followed by a ciphertext.
func AesGcmDecrypt(key []byte, dataB64 string) ([]byte, error) {
if key == nil {
return nil, errors.New("no key provided")
}
data, err := base64.StdEncoding.DecodeString(dataB64)
if err != nil {
return nil, errors.Wrap(err, "decoding base64 data")
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, errors.Wrap(err, "initializing aes")
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, errors.Wrap(err, "initializing gcm")
}
if len(data) < aesGcmNonceSize {
return nil, errors.Wrap(err, "malformed data")
}
nonce, ciphertext := data[:aesGcmNonceSize], data[aesGcmNonceSize:]
plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, errors.Wrap(err, "decrypting")
}
return plaintext, nil
}

100
crypt/utils_test.go Normal file
View file

@ -0,0 +1,100 @@
package crypt
import (
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"fmt"
"testing"
"github.com/dnote/cli/testutils"
"github.com/pkg/errors"
)
func TestAesGcmEncrypt(t *testing.T) {
testCases := []struct {
key []byte
plaintext []byte
}{
{
key: []byte("AES256Key-32Characters1234567890"),
plaintext: []byte("foo bar baz quz"),
},
{
key: []byte("AES256Key-32Charactersabcdefghij"),
plaintext: []byte("1234 foo 5678 bar 7890 baz"),
},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("key %s plaintext %s", tc.key, tc.plaintext), func(t *testing.T) {
// encrypt
dataB64, err := AesGcmEncrypt(tc.key, tc.plaintext)
if err != nil {
t.Fatal(errors.Wrap(err, "performing encryption"))
}
// test that data can be decrypted
data, err := base64.StdEncoding.DecodeString(dataB64)
if err != nil {
t.Fatal(errors.Wrap(err, "decoding data from base64"))
}
nonce, ciphertext := data[:12], data[12:]
fmt.Println(string(data))
block, err := aes.NewCipher([]byte(tc.key))
if err != nil {
t.Fatal(errors.Wrap(err, "initializing aes"))
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
t.Fatal(errors.Wrap(err, "initializing gcm"))
}
plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
t.Fatal(errors.Wrap(err, "decode"))
}
testutils.AssertDeepEqual(t, plaintext, tc.plaintext, "plaintext mismatch")
})
}
}
func TestAesGcmDecrypt(t *testing.T) {
testCases := []struct {
key []byte
ciphertextB64 string
expectedPlaintext string
}{
{
key: []byte("AES256Key-32Characters1234567890"),
ciphertextB64: "M2ov9hWMQ52v1S/zigwX3bJt4cVCV02uiRm/grKqN/rZxNkJrD7vK4Ii0g==",
expectedPlaintext: "foo bar baz quz",
},
{
key: []byte("AES256Key-32Characters1234567890"),
ciphertextB64: "M4csFKUIUbD1FBEzLgHjscoKgN0lhMGJ0n2nKWiCkE/qSKlRP7kS",
expectedPlaintext: "foo\n1\nbar\n2",
},
{
key: []byte("AES256Key-32Characters1234567890"),
ciphertextB64: "pe/fnw73MR1clmVIlRSJ5gDwBdnPly/DF7DsR5dJVz4dHZlv0b10WzvJEGOCHZEr+Q==",
expectedPlaintext: "föo\nbār\nbåz & qūz",
},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("key %s ciphertext %s", tc.key, tc.ciphertextB64), func(t *testing.T) {
plaintext, err := AesGcmDecrypt(tc.key, tc.ciphertextB64)
if err != nil {
t.Fatal(errors.Wrap(err, "performing decryption"))
}
testutils.AssertDeepEqual(t, plaintext, []byte(tc.expectedPlaintext), "plaintext mismatch")
})
}
}

View file

@ -3,6 +3,7 @@ package infra
import (
"database/sql"
"encoding/base64"
"fmt"
"os"
"os/user"
@ -27,21 +28,29 @@ var (
SystemLastMaxUSN = "last_max_usn"
// SystemLastUpgrade is the timestamp at which the system more recently checked for an upgrade
SystemLastUpgrade = "last_upgrade"
// SystemCipherKey is the encryption key
SystemCipherKey = "enc_key"
// SystemSessionKey is the session key
SystemSessionKey = "session_token"
// SystemSessionKeyExpiry is the timestamp at which the session key will expire
SystemSessionKeyExpiry = "session_token_expiry"
)
// DnoteCtx is a context holding the information of the current runtime
type DnoteCtx struct {
HomeDir string
DnoteDir string
APIEndpoint string
Version string
DB *sql.DB
HomeDir string
DnoteDir string
APIEndpoint string
Version string
DB *DB
SessionKey string
SessionKeyExpiry int64
CipherKey []byte
}
// Config holds dnote configuration
type Config struct {
Editor string
APIKey string
}
// NewCtx returns a new dnote context
@ -53,7 +62,7 @@ func NewCtx(apiEndpoint, versionTag string) (DnoteCtx, error) {
dnoteDir := getDnoteDir(homeDir)
dnoteDBPath := fmt.Sprintf("%s/dnote.db", dnoteDir)
db, err := sql.Open("sqlite3", dnoteDBPath)
db, err := OpenDB(dnoteDBPath)
if err != nil {
return DnoteCtx{}, errors.Wrap(err, "conntecting to db")
}
@ -69,6 +78,45 @@ func NewCtx(apiEndpoint, versionTag string) (DnoteCtx, error) {
return ret, nil
}
// SetupCtx populates context and returns a new context
func SetupCtx(ctx DnoteCtx) (DnoteCtx, error) {
db := ctx.DB
var sessionKey, cipherKeyB64 string
var sessionKeyExpiry int64
err := db.QueryRow("SELECT value FROM system WHERE key = ?", SystemSessionKey).Scan(&sessionKey)
if err != nil && err != sql.ErrNoRows {
return ctx, errors.Wrap(err, "finding sesison key")
}
err = db.QueryRow("SELECT value FROM system WHERE key = ?", SystemCipherKey).Scan(&cipherKeyB64)
if err != nil && err != sql.ErrNoRows {
return ctx, errors.Wrap(err, "finding sesison key")
}
err = db.QueryRow("SELECT value FROM system WHERE key = ?", SystemSessionKeyExpiry).Scan(&sessionKeyExpiry)
if err != nil && err != sql.ErrNoRows {
return ctx, errors.Wrap(err, "finding sesison key expiry")
}
cipherKey, err := base64.StdEncoding.DecodeString(cipherKeyB64)
if err != nil {
return ctx, errors.Wrap(err, "decoding cipherKey from base64")
}
ret := DnoteCtx{
HomeDir: ctx.HomeDir,
DnoteDir: ctx.DnoteDir,
APIEndpoint: ctx.APIEndpoint,
Version: ctx.Version,
DB: ctx.DB,
SessionKey: sessionKey,
SessionKeyExpiry: sessionKeyExpiry,
CipherKey: cipherKey,
}
return ret, nil
}
func getDnoteDir(homeDir string) string {
var ret string

120
infra/sql.go Normal file
View file

@ -0,0 +1,120 @@
package infra
import (
"database/sql"
"github.com/pkg/errors"
)
// SQLCommon is the minimal interface required by a db connection
type SQLCommon interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
// sqlDb is an interface implemented by *sql.DB
type sqlDb interface {
Begin() (*sql.Tx, error)
}
// sqlTx is an interface implemented by *sql.Tx
type sqlTx interface {
Commit() error
Rollback() error
}
// DB contains information about the current database connection
type DB struct {
Conn SQLCommon
}
// OpenDB initializes a new connection to the sqlite database
func OpenDB(dbPath string) (*DB, error) {
dbConn, err := sql.Open("sqlite3", dbPath)
if err != nil {
return nil, errors.Wrap(err, "opening db connection")
}
// Send a ping to ensure that the connection is established
// if err := dbConn.Ping(); err != nil {
// dbConn.Close()
// return nil, errors.Wrap(err, "ping")
// }
db := &DB{
Conn: dbConn,
}
return db, nil
}
// Begin begins a transaction
func (d *DB) Begin() (*DB, error) {
if db, ok := d.Conn.(sqlDb); ok && db != nil {
tx, err := db.Begin()
if err != nil {
return nil, err
}
return &DB{Conn: tx}, nil
}
return nil, errors.New("can't start transaction")
}
// Commit commits a transaction
func (d *DB) Commit() error {
if db, ok := d.Conn.(sqlTx); ok && db != nil {
if err := db.Commit(); err != nil {
return err
}
}
return errors.New("invalid transaction")
}
// Rollback rolls back a transaction
func (d *DB) Rollback() error {
if db, ok := d.Conn.(sqlTx); ok && db != nil {
if err := db.Rollback(); err != nil {
return err
}
}
return errors.New("invalid transaction")
}
// Exec executes a sql
func (d *DB) Exec(query string, values ...interface{}) (sql.Result, error) {
return d.Conn.Exec(query, values...)
}
// Prepare prepares a sql
func (d *DB) Prepare(query string) (*sql.Stmt, error) {
return d.Conn.Prepare(query)
}
// Query queries rows
func (d *DB) Query(query string, values ...interface{}) (*sql.Rows, error) {
return d.Conn.Query(query, values...)
}
// QueryRow queries a row
func (d *DB) QueryRow(query string, values ...interface{}) *sql.Row {
return d.Conn.QueryRow(query, values...)
}
type closer interface {
Close() error
}
// Close closes a db connection
func (d *DB) Close() error {
if db, ok := d.Conn.(closer); ok {
return db.Close()
}
return errors.New("can't close db")
}

View file

@ -16,14 +16,14 @@ var (
// ColorBlue is a blue foreground color
ColorBlue = color.New(color.FgBlue)
// ColorGray is a gray foreground color
ColorGray = color.New(color.FgWhite)
ColorGray = color.New(color.FgHiBlack)
)
var indent = " "
// Info prints information
func Info(msg string) {
fmt.Fprintf(color.Output, "%s%s %s\n", indent, ColorBlue.Sprint("•"), msg)
fmt.Fprintf(color.Output, "%s%s %s", indent, ColorBlue.Sprint("•"), msg)
}
// Infof prints information with optional format verbs
@ -58,7 +58,7 @@ func Warnf(msg string, v ...interface{}) {
// Error prints an error message
func Error(msg string) {
fmt.Fprintf(color.Output, "%s%s %s\n", indent, ColorRed.Sprint(""), msg)
fmt.Fprintf(color.Output, "%s%s %s", indent, ColorRed.Sprint(""), msg)
}
// Errorf prints an error message with optional format verbs
@ -71,6 +71,21 @@ func Printf(msg string, v ...interface{}) {
fmt.Fprintf(color.Output, "%s%s %s", indent, ColorGray.Sprint("•"), fmt.Sprintf(msg, v...))
}
// Askf prints an question with optional format verbs. The leading symbol differs in color depending
// on whether the input is masked.
func Askf(msg string, masked bool, v ...interface{}) {
symbolChar := "[?]"
var symbol string
if masked {
symbol = ColorGray.Sprintf(symbolChar)
} else {
symbol = ColorGreen.Sprintf(symbolChar)
}
fmt.Fprintf(color.Output, "%s%s %s: ", indent, symbol, fmt.Sprintf(msg, v...))
}
// Debug prints to the console if DNOTE_DEBUG is set
func Debug(msg string, v ...interface{}) {
if os.Getenv("DNOTE_DEBUG") == "1" {

View file

@ -15,6 +15,7 @@ import (
"github.com/dnote/cli/cmd/edit"
"github.com/dnote/cli/cmd/find"
"github.com/dnote/cli/cmd/login"
"github.com/dnote/cli/cmd/logout"
"github.com/dnote/cli/cmd/ls"
"github.com/dnote/cli/cmd/remove"
"github.com/dnote/cli/cmd/sync"
@ -37,9 +38,15 @@ func main() {
panic(errors.Wrap(err, "preparing dnote run"))
}
ctx, err = infra.SetupCtx(ctx)
if err != nil {
panic(errors.Wrap(err, "setting up context"))
}
root.Register(remove.NewCmd(ctx))
root.Register(edit.NewCmd(ctx))
root.Register(login.NewCmd(ctx))
root.Register(logout.NewCmd(ctx))
root.Register(add.NewCmd(ctx))
root.Register(ls.NewCmd(ctx))
root.Register(sync.NewCmd(ctx))

View file

@ -1,7 +1,6 @@
package migrate
import (
"database/sql"
"encoding/json"
"fmt"
"net/http"
@ -39,13 +38,13 @@ func TestExecute_bump_schema(t *testing.T) {
m1 := migration{
name: "noop",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, db *infra.DB) error {
return nil
},
}
m2 := migration{
name: "noop",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, db *infra.DB) error {
return nil
},
}
@ -97,28 +96,28 @@ func TestRun_nonfresh(t *testing.T) {
sequence := []migration{
migration{
name: "v1",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, db *infra.DB) error {
testutils.MustExec(t, "marking v1 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v1")
return nil
},
},
migration{
name: "v2",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, db *infra.DB) error {
testutils.MustExec(t, "marking v2 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v2")
return nil
},
},
migration{
name: "v3",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, db *infra.DB) error {
testutils.MustExec(t, "marking v3 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v3")
return nil
},
},
migration{
name: "v4",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, db *infra.DB) error {
testutils.MustExec(t, "marking v4 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v4")
return nil
},
@ -176,21 +175,21 @@ func TestRun_fresh(t *testing.T) {
sequence := []migration{
migration{
name: "v1",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, db *infra.DB) error {
testutils.MustExec(t, "marking v1 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v1")
return nil
},
},
migration{
name: "v2",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, db *infra.DB) error {
testutils.MustExec(t, "marking v2 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v2")
return nil
},
},
migration{
name: "v3",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, db *infra.DB) error {
testutils.MustExec(t, "marking v3 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v3")
return nil
},
@ -250,21 +249,21 @@ func TestRun_up_to_date(t *testing.T) {
sequence := []migration{
migration{
name: "v1",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, db *infra.DB) error {
testutils.MustExec(t, "marking v1 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v1")
return nil
},
},
migration{
name: "v2",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, db *infra.DB) error {
testutils.MustExec(t, "marking v2 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v2")
return nil
},
},
migration{
name: "v3",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, db *infra.DB) error {
testutils.MustExec(t, "marking v3 completed", db, "INSERT INTO migrate_run_test (name) VALUES (?)", "v3")
return nil
},
@ -876,6 +875,7 @@ func TestLocalMigration9(t *testing.T) {
func TestRemoteMigration1(t *testing.T) {
// set up
ctx := testutils.InitEnv(t, "../tmp", "./fixtures/remote-1-pre-schema.sql", false)
testutils.Login(t, &ctx)
defer testutils.TeardownEnv(ctx)
JSBookUUID := "existing-js-book-uuid"
@ -914,14 +914,13 @@ func TestRemoteMigration1(t *testing.T) {
ctx.APIEndpoint = server.URL
confStr := fmt.Sprintf("apikey: mock_api_key")
testutils.WriteFile(ctx, []byte(confStr), "dnoterc")
db := ctx.DB
testutils.MustExec(t, "inserting js book", db, "INSERT INTO books (uuid, label) VALUES (?, ?)", JSBookUUID, "js")
testutils.MustExec(t, "inserting css book", db, "INSERT INTO books (uuid, label) VALUES (?, ?)", CSSBookUUID, "css")
testutils.MustExec(t, "inserting linux book", db, "INSERT INTO books (uuid, label) VALUES (?, ?)", linuxBookUUID, "linux")
testutils.MustExec(t, "inserting sessionKey", db, "INSERT INTO system (key, value) VALUES (?, ?)", infra.SystemSessionKey, "someSessionKey")
testutils.MustExec(t, "inserting sessionKeyExpiry", db, "INSERT INTO system (key, value) VALUES (?, ?)", infra.SystemSessionKeyExpiry, time.Now().Add(24*time.Hour).Unix())
tx, err := db.Begin()
if err != nil {

View file

@ -7,7 +7,6 @@ import (
"github.com/dnote/actions"
"github.com/dnote/cli/client"
"github.com/dnote/cli/core"
"github.com/dnote/cli/infra"
"github.com/dnote/cli/log"
"github.com/pkg/errors"
@ -15,12 +14,12 @@ import (
type migration struct {
name string
run func(ctx infra.DnoteCtx, tx *sql.Tx) error
run func(ctx infra.DnoteCtx, tx *infra.DB) error
}
var lm1 = migration{
name: "upgrade-edit-note-from-v1-to-v3",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, tx *infra.DB) error {
rows, err := tx.Query("SELECT uuid, data FROM actions WHERE type = ? AND schema = ?", "edit_note", 1)
if err != nil {
return errors.Wrap(err, "querying rows")
@ -68,7 +67,7 @@ var lm1 = migration{
var lm2 = migration{
name: "upgrade-edit-note-from-v2-to-v3",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, tx *infra.DB) error {
rows, err := tx.Query("SELECT uuid, data FROM actions WHERE type = ? AND schema = ?", "edit_note", 2)
if err != nil {
return errors.Wrap(err, "querying rows")
@ -113,7 +112,7 @@ var lm2 = migration{
var lm3 = migration{
name: "upgrade-remove-note-from-v1-to-v2",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, tx *infra.DB) error {
rows, err := tx.Query("SELECT uuid, data FROM actions WHERE type = ? AND schema = ?", "remove_note", 1)
if err != nil {
return errors.Wrap(err, "querying rows")
@ -155,7 +154,7 @@ var lm3 = migration{
var lm4 = migration{
name: "add-dirty-usn-deleted-to-notes-and-books",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, tx *infra.DB) error {
_, err := tx.Exec("ALTER TABLE books ADD COLUMN dirty bool DEFAULT false")
if err != nil {
return errors.Wrap(err, "adding dirty column to books")
@ -192,7 +191,7 @@ var lm4 = migration{
var lm5 = migration{
name: "mark-action-targets-dirty",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, tx *infra.DB) error {
rows, err := tx.Query("SELECT uuid, data, type FROM actions")
if err != nil {
return errors.Wrap(err, "querying rows")
@ -254,7 +253,7 @@ var lm5 = migration{
var lm6 = migration{
name: "drop-actions",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, tx *infra.DB) error {
_, err := tx.Exec("DROP TABLE actions;")
if err != nil {
return errors.Wrap(err, "dropping the actions table")
@ -266,7 +265,7 @@ var lm6 = migration{
var lm7 = migration{
name: "resolve-conflicts-with-reserved-book-names",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, tx *infra.DB) error {
migrateBook := func(name string) error {
var uuid string
@ -313,7 +312,7 @@ var lm7 = migration{
var lm8 = migration{
name: "drop-note-id-and-rename-content-to-body",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, tx *infra.DB) error {
_, err := tx.Exec(`CREATE TABLE notes_tmp
(
uuid text NOT NULL,
@ -352,7 +351,7 @@ var lm8 = migration{
var lm9 = migration{
name: "create-fts-index",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
run: func(ctx infra.DnoteCtx, tx *infra.DB) error {
_, err := tx.Exec(`CREATE VIRTUAL TABLE IF NOT EXISTS note_fts USING fts5(content=notes, body, tokenize="porter unicode61 categories 'L* N* Co Ps Pe'");`)
if err != nil {
return errors.Wrap(err, "creating note_fts")
@ -388,16 +387,13 @@ var lm9 = migration{
var rm1 = migration{
name: "sync-book-uuids-from-server",
run: func(ctx infra.DnoteCtx, tx *sql.Tx) error {
config, err := core.ReadConfig(ctx)
if err != nil {
return errors.Wrap(err, "reading the config")
}
if config.APIKey == "" {
return errors.New("login required")
run: func(ctx infra.DnoteCtx, tx *infra.DB) error {
sessionKey := ctx.SessionKey
if sessionKey == "" {
return errors.New("not logged in")
}
resp, err := client.GetBooks(ctx, config.APIKey)
resp, err := client.GetBooks(ctx, sessionKey)
if err != nil {
return errors.Wrap(err, "getting books from the server")
}

View file

@ -1,12 +1,17 @@
#!/bin/bash
# run_server_test.sh runs server test files sequentially
# https://stackoverflow.com/questions/23715302/go-how-to-run-tests-for-multiple-packages
set -e
# clear tmp dir in case not properly torn down
rm -rf ./tmp
rm -rf $GOPATH/src/github.com/dnote/cli/tmp
# run test
pushd $GOPATH/src/github.com/dnote/cli
go test ./... \
-p 1\
--tags "fts5"
popd

View file

@ -14,6 +14,7 @@ import (
"reflect"
"strings"
"testing"
"time"
"github.com/dnote/cli/infra"
"github.com/dnote/cli/utils"
@ -21,13 +22,8 @@ import (
)
// InitEnv sets up a test env and returns a new dnote context
func InitEnv(t *testing.T, relPath string, relFixturePath string, migrated bool) infra.DnoteCtx {
path, err := filepath.Abs(relPath)
if err != nil {
t.Fatal(errors.Wrap(err, "pasrsing path"))
}
os.Setenv("DNOTE_HOME_DIR", path)
func InitEnv(t *testing.T, dnotehomePath string, fixturePath string, migrated bool) infra.DnoteCtx {
os.Setenv("DNOTE_HOME_DIR", dnotehomePath)
ctx, err := infra.NewCtx("", "")
if err != nil {
t.Fatal(errors.Wrap(err, "getting new ctx"))
@ -39,7 +35,7 @@ func InitEnv(t *testing.T, relPath string, relFixturePath string, migrated bool)
}
// set up db
b := ReadFileAbs(relFixturePath)
b := ReadFileAbs(fixturePath)
setupSQL := string(b)
db := ctx.DB
@ -61,6 +57,19 @@ func InitEnv(t *testing.T, relPath string, relFixturePath string, migrated bool)
return ctx
}
// Login simulates a logged in user by inserting credentials in the local database
func Login(t *testing.T, ctx *infra.DnoteCtx) {
db := ctx.DB
MustExec(t, "inserting sessionKey", db, "INSERT INTO system (key, value) VALUES (?, ?)", infra.SystemSessionKey, "someSessionKey")
MustExec(t, "inserting sessionKeyExpiry", db, "INSERT INTO system (key, value) VALUES (?, ?)", infra.SystemSessionKeyExpiry, time.Now().Add(24*time.Hour).Unix())
MustExec(t, "inserting cipherKey", db, "INSERT INTO system (key, value) VALUES (?, ?)", infra.SystemCipherKey, "QUVTMjU2S2V5LTMyQ2hhcmFjdGVyczEyMzQ1Njc4OTA=")
ctx.SessionKey = "someSessionKey"
ctx.SessionKeyExpiry = time.Now().Add(24 * time.Hour).Unix()
ctx.CipherKey = []byte("AES256Key-32Characters1234567890")
}
// TeardownEnv cleans up the test env represented by the given context
func TeardownEnv(ctx infra.DnoteCtx) {
ctx.DB.Close()
@ -210,7 +219,7 @@ func IsEqualJSON(s1, s2 []byte) (bool, error) {
}
// MustExec executes the given SQL query and fails a test if an error occurs
func MustExec(t *testing.T, message string, db *sql.DB, query string, args ...interface{}) sql.Result {
func MustExec(t *testing.T, message string, db *infra.DB, query string, args ...interface{}) sql.Result {
result, err := db.Exec(query, args...)
if err != nil {
t.Fatal(errors.Wrap(errors.Wrap(err, "executing sql"), message))
@ -309,7 +318,7 @@ func WaitDnoteCmd(t *testing.T, ctx infra.DnoteCtx, runFunc func(io.WriteCloser)
func UserConfirm(stdin io.WriteCloser) error {
// confirm
if _, err := io.WriteString(stdin, "y\n"); err != nil {
return errors.Wrap(err, "confirming deletion")
return errors.Wrap(err, "indicating confirmation in stdin")
}
return nil

View file

@ -9,11 +9,13 @@ import (
"os"
"path/filepath"
"strings"
"syscall"
"github.com/dnote/cli/infra"
"github.com/dnote/cli/log"
"github.com/pkg/errors"
"github.com/satori/go.uuid"
"golang.org/x/crypto/ssh/terminal"
)
// GenerateUUID returns a uid
@ -28,7 +30,38 @@ func getInput() (string, error) {
return "", errors.Wrap(err, "reading stdin")
}
return input, nil
return strings.Trim(input, "\r\n"), nil
}
// PromptInput prompts the user input and saves the result to the destination
func PromptInput(message string, dest *string) error {
log.Askf(message, false)
input, err := getInput()
if err != nil {
return errors.Wrap(err, "getting user input")
}
*dest = input
return nil
}
// PromptPassword prompts the user input a password and saves the result to the destination.
// The input is masked, meaning it is not echoed on the terminal.
func PromptPassword(message string, dest *string) error {
log.Askf(message, true)
password, err := terminal.ReadPassword(syscall.Stdin)
if err != nil {
return errors.Wrap(err, "getting user input")
}
fmt.Println("")
*dest = string(password)
return nil
}
// AskConfirmation prompts for user input to confirm a choice
@ -40,17 +73,17 @@ func AskConfirmation(question string, optimistic bool) (bool, error) {
choices = "(y/N)"
}
log.Printf("%s %s: ", question, choices)
message := fmt.Sprintf("%s %s", question, choices)
res, err := getInput()
if err != nil {
var input string
if err := PromptInput(message, &input); err != nil {
return false, errors.Wrap(err, "Failed to get user input")
}
confirmed := res == "y\n" || res == "y\r\n"
confirmed := input == "y"
if optimistic {
confirmed = confirmed || res == "\n" || res == "\r\n"
confirmed = confirmed || input == ""
}
return confirmed, nil
@ -109,18 +142,48 @@ func CopyDir(src, dest string) error {
return nil
}
// DoAuthorizedReq does a http request to the given path in the api endpoint as a user,
// with the appropriate headers. The given path should include the preceding slash.
func DoAuthorizedReq(ctx infra.DnoteCtx, apiKey, method, path, body string) (*http.Response, error) {
func getReq(ctx infra.DnoteCtx, path, method, body string) (*http.Request, error) {
endpoint := fmt.Sprintf("%s%s", ctx.APIEndpoint, path)
req, err := http.NewRequest(method, endpoint, strings.NewReader(body))
if err != nil {
return nil, errors.Wrap(err, "constructing http request")
}
req.Header.Set("Authorization", apiKey)
req.Header.Set("CLI-Version", ctx.Version)
return req, nil
}
// DoAuthorizedReq does a http request to the given path in the api endpoint as a user,
// with the appropriate headers. The given path should include the preceding slash.
func DoAuthorizedReq(ctx infra.DnoteCtx, hc http.Client, method, path, body string) (*http.Response, error) {
if ctx.SessionKey == "" {
return nil, errors.New("no session key found")
}
req, err := getReq(ctx, path, method, body)
if err != nil {
return nil, errors.Wrap(err, "getting request")
}
credential := fmt.Sprintf("Bearer %s", ctx.SessionKey)
req.Header.Set("Authorization", credential)
res, err := hc.Do(req)
if err != nil {
return res, errors.Wrap(err, "making http request")
}
return res, nil
}
// DoReq does a http request to the given path in the api endpoint
func DoReq(ctx infra.DnoteCtx, method, path, body string) (*http.Response, error) {
req, err := getReq(ctx, path, method, body)
if err != nil {
return nil, errors.Wrap(err, "getting request")
}
hc := http.Client{}
res, err := hc.Do(req)
if err != nil {