diff --git a/pkg/server/database/database.go b/pkg/server/database/database.go
index 1bb9e159..d2d07992 100644
--- a/pkg/server/database/database.go
+++ b/pkg/server/database/database.go
@@ -19,11 +19,7 @@
package database
import (
- "fmt"
- "os"
-
"github.com/jinzhu/gorm"
- "github.com/pkg/errors"
// Use postgres
_ "github.com/lib/pq"
@@ -34,97 +30,13 @@ var (
MigrationTableName = "migrations"
)
-// Config holds the connection configuration
-type Config struct {
- Host string
- Port string
- Name string
- User string
- Password string
-}
-
-// ErrConfigMissingHost is an error for an incomplete configuration missing the host
-var ErrConfigMissingHost = errors.New("Host is empty")
-
-// ErrConfigMissingPort is an error for an incomplete configuration missing the port
-var ErrConfigMissingPort = errors.New("Port is empty")
-
-// ErrConfigMissingName is an error for an incomplete configuration missing the name
-var ErrConfigMissingName = errors.New("Name is empty")
-
-// ErrConfigMissingUser is an error for an incomplete configuration missing the user
-var ErrConfigMissingUser = errors.New("User is empty")
-
-func validateConfig(c Config) error {
- if c.Host == "" {
- return ErrConfigMissingHost
- }
- if c.Port == "" {
- return ErrConfigMissingPort
- }
- if c.Name == "" {
- return ErrConfigMissingName
- }
- if c.User == "" {
- return ErrConfigMissingUser
- }
-
- return nil
-}
-
-func getPGConnectionString(c Config) (string, error) {
- if err := validateConfig(c); err != nil {
- return "", errors.Wrap(err, "invalid database config")
- }
-
- var sslmode string
- if os.Getenv("GO_ENV") == "PRODUCTION" && os.Getenv("DB_NOSSL") == "" {
- sslmode = "require"
- } else {
- sslmode = "disable"
- }
-
- return fmt.Sprintf(
- "sslmode=%s host=%s port=%s dbname=%s user=%s password=%s",
- sslmode,
- c.Host,
- c.Port,
- c.Name,
- c.User,
- c.Password,
- ), nil
-}
-
-var (
- // DBConn is the connection handle for the database
- DBConn *gorm.DB
-)
-
-// Open opens the connection with the database
-func Open(c Config) {
- connStr, err := getPGConnectionString(c)
- if err != nil {
- panic(err)
- }
-
- DBConn, err = gorm.Open("postgres", connStr)
- if err != nil {
- panic(err)
- }
-}
-
-// Close closes database connection
-func Close() {
- DBConn.Close()
-}
-
// InitSchema migrates database schema to reflect the latest model definition
-func InitSchema() {
- if err := DBConn.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`).Error; err != nil {
+func InitSchema(db *gorm.DB) {
+ if err := db.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`).Error; err != nil {
panic(err)
}
- if err := DBConn.AutoMigrate(
+ if err := db.AutoMigrate(
Note{},
Book{},
User{},
diff --git a/pkg/server/database/migrate.go b/pkg/server/database/migrate.go
index 9efbf610..217c0c1f 100644
--- a/pkg/server/database/migrate.go
+++ b/pkg/server/database/migrate.go
@@ -22,20 +22,20 @@ import (
"log"
"github.com/gobuffalo/packr/v2"
+ "github.com/jinzhu/gorm"
"github.com/pkg/errors"
"github.com/rubenv/sql-migrate"
)
// Migrate runs the migrations
-func Migrate() error {
+func Migrate(db *gorm.DB) error {
migrations := &migrate.PackrMigrationSource{
Box: packr.New("migrations", "../database/migrations/"),
}
migrate.SetTable(MigrationTableName)
- db := DBConn.DB()
- n, err := migrate.Exec(db, "postgres", migrations, migrate.Up)
+ n, err := migrate.Exec(db.DB(), "postgres", migrations, migrate.Up)
if err != nil {
return errors.Wrap(err, "running migrations")
}
diff --git a/pkg/server/database/migrate/main.go b/pkg/server/database/migrate/main.go
index 06d5b705..9d3e4805 100644
--- a/pkg/server/database/migrate/main.go
+++ b/pkg/server/database/migrate/main.go
@@ -23,7 +23,7 @@ import (
"fmt"
"os"
- "github.com/dnote/dnote/pkg/server/database"
+ "github.com/dnote/dnote/pkg/server/dbconn"
"github.com/joho/godotenv"
"github.com/pkg/errors"
"github.com/rubenv/sql-migrate"
@@ -43,20 +43,18 @@ func init() {
}
}
- c := database.Config{
- Host: os.Getenv("DBHost"),
- Port: os.Getenv("DBPort"),
- Name: os.Getenv("DBName"),
- User: os.Getenv("DBUser"),
- Password: os.Getenv("DBPassword"),
- }
- database.Open(c)
}
func main() {
flag.Parse()
- db := database.DBConn
+ db := dbconn.Open(dbconn.Config{
+ Host: os.Getenv("DBHost"),
+ Port: os.Getenv("DBPort"),
+ Name: os.Getenv("DBName"),
+ User: os.Getenv("DBUser"),
+ Password: os.Getenv("DBPassword"),
+ })
migrations := &migrate.FileMigrationSource{
Dir: *migrationDir,
diff --git a/pkg/server/dbconn/dbconn.go b/pkg/server/dbconn/dbconn.go
new file mode 100644
index 00000000..d55d1400
--- /dev/null
+++ b/pkg/server/dbconn/dbconn.go
@@ -0,0 +1,85 @@
+package dbconn
+
+import (
+ "fmt"
+
+ "github.com/jinzhu/gorm"
+ "github.com/pkg/errors"
+)
+
+// Config holds the connection configuration
+type Config struct {
+ SkipSSL bool
+ Host string
+ Port string
+ Name string
+ User string
+ Password string
+}
+
+// ErrConfigMissingHost is an error for an incomplete configuration missing the host
+var ErrConfigMissingHost = errors.New("Host is empty")
+
+// ErrConfigMissingPort is an error for an incomplete configuration missing the port
+var ErrConfigMissingPort = errors.New("Port is empty")
+
+// ErrConfigMissingName is an error for an incomplete configuration missing the name
+var ErrConfigMissingName = errors.New("Name is empty")
+
+// ErrConfigMissingUser is an error for an incomplete configuration missing the user
+var ErrConfigMissingUser = errors.New("User is empty")
+
+func validateConfig(c Config) error {
+ if c.Host == "" {
+ return ErrConfigMissingHost
+ }
+ if c.Port == "" {
+ return ErrConfigMissingPort
+ }
+ if c.Name == "" {
+ return ErrConfigMissingName
+ }
+ if c.User == "" {
+ return ErrConfigMissingUser
+ }
+
+ return nil
+}
+
+func getPGConnectionString(c Config) (string, error) {
+ if err := validateConfig(c); err != nil {
+ return "", errors.Wrap(err, "invalid database config")
+ }
+
+ var sslmode string
+ if c.SkipSSL {
+ sslmode = "disable"
+ } else {
+ sslmode = "require"
+ }
+
+ return fmt.Sprintf(
+ "sslmode=%s host=%s port=%s dbname=%s user=%s password=%s",
+ sslmode,
+ c.Host,
+ c.Port,
+ c.Name,
+ c.User,
+ c.Password,
+ ), nil
+}
+
+// Open opens the connection with the database
+func Open(c Config) *gorm.DB {
+ connStr, err := getPGConnectionString(c)
+ if err != nil {
+ panic(errors.Wrap(err, "getting connection string"))
+ }
+
+ conn, err := gorm.Open("postgres", connStr)
+ if err != nil {
+ panic(errors.Wrap(err, "opening database connection"))
+ }
+
+ return conn
+}
diff --git a/pkg/server/database/database_test.go b/pkg/server/dbconn/dbconn_test.go
similarity index 90%
rename from pkg/server/database/database_test.go
rename to pkg/server/dbconn/dbconn_test.go
index 2079acf1..44d55b7e 100644
--- a/pkg/server/database/database_test.go
+++ b/pkg/server/dbconn/dbconn_test.go
@@ -16,7 +16,7 @@
* along with Dnote. If not, see .
*/
-package database
+package dbconn
import (
"github.com/dnote/dnote/pkg/assert"
@@ -38,6 +38,17 @@ func TestValidateConfig(t *testing.T) {
},
expected: nil,
},
+ {
+ input: Config{
+ SkipSSL: true,
+ Host: "mockHost",
+ Port: "mockPort",
+ Name: "mockName",
+ User: "mockUser",
+ Password: "mockPassword",
+ },
+ expected: nil,
+ },
{
input: Config{
Host: "mockHost",
diff --git a/pkg/server/handlers/auth.go b/pkg/server/handlers/auth.go
index 32ed6da4..f11dbecc 100644
--- a/pkg/server/handlers/auth.go
+++ b/pkg/server/handlers/auth.go
@@ -24,10 +24,10 @@ import (
"net/http"
"time"
- "github.com/dnote/dnote/pkg/server/helpers"
- "github.com/dnote/dnote/pkg/server/operations"
"github.com/dnote/dnote/pkg/server/database"
+ "github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/mailer"
+ "github.com/dnote/dnote/pkg/server/operations"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
)
@@ -60,10 +60,8 @@ func (a *App) getMe(w http.ResponseWriter, r *http.Request) {
return
}
- db := database.DBConn
-
var account database.Account
- if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
+ if err := a.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
HandleError(w, "finding account", err, http.StatusInternalServerError)
return
}
@@ -76,7 +74,7 @@ func (a *App) getMe(w http.ResponseWriter, r *http.Request) {
User: session,
}
- tx := db.Begin()
+ tx := a.DB.Begin()
if err := operations.TouchLastLoginAt(user, tx); err != nil {
tx.Rollback()
// In case of an error, gracefully continue to avoid disturbing the service
@@ -92,8 +90,6 @@ type createResetTokenPayload struct {
}
func (a *App) createResetToken(w http.ResponseWriter, r *http.Request) {
- db := database.DBConn
-
var params createResetTokenPayload
if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil {
http.Error(w, "invalid payload", http.StatusBadRequest)
@@ -101,7 +97,7 @@ func (a *App) createResetToken(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
- conn := db.Where("email = ?", params.Email).First(&account)
+ conn := a.DB.Where("email = ?", params.Email).First(&account)
if conn.RecordNotFound() {
return
}
@@ -127,7 +123,7 @@ func (a *App) createResetToken(w http.ResponseWriter, r *http.Request) {
Type: database.TokenTypeResetPassword,
}
- if err := db.Save(&token).Error; err != nil {
+ if err := a.DB.Save(&token).Error; err != nil {
HandleError(w, errors.Wrap(err, "saving token").Error(), nil, http.StatusInternalServerError)
return
}
@@ -156,8 +152,6 @@ type resetPasswordPayload struct {
}
func (a *App) resetPassword(w http.ResponseWriter, r *http.Request) {
- db := database.DBConn
-
var params resetPasswordPayload
if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil {
http.Error(w, "invalid payload", http.StatusBadRequest)
@@ -165,7 +159,7 @@ func (a *App) resetPassword(w http.ResponseWriter, r *http.Request) {
}
var token database.Token
- conn := db.Where("value = ? AND type =? AND used_at IS NULL", params.Token, database.TokenTypeResetPassword).First(&token)
+ conn := a.DB.Where("value = ? AND type =? AND used_at IS NULL", params.Token, database.TokenTypeResetPassword).First(&token)
if conn.RecordNotFound() {
http.Error(w, "invalid token", http.StatusBadRequest)
return
@@ -186,7 +180,7 @@ func (a *App) resetPassword(w http.ResponseWriter, r *http.Request) {
return
}
- tx := db.Begin()
+ tx := a.DB.Begin()
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(params.Password), bcrypt.DefaultCost)
if err != nil {
@@ -196,7 +190,7 @@ func (a *App) resetPassword(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
- if err := db.Where("user_id = ?", token.UserID).First(&account).Error; err != nil {
+ if err := a.DB.Where("user_id = ?", token.UserID).First(&account).Error; err != nil {
tx.Rollback()
HandleError(w, errors.Wrap(err, "finding user").Error(), nil, http.StatusInternalServerError)
return
@@ -216,10 +210,10 @@ func (a *App) resetPassword(w http.ResponseWriter, r *http.Request) {
tx.Commit()
var user database.User
- if err := db.Where("id = ?", account.UserID).First(&user).Error; err != nil {
+ if err := a.DB.Where("id = ?", account.UserID).First(&user).Error; err != nil {
HandleError(w, errors.Wrap(err, "finding user").Error(), nil, http.StatusInternalServerError)
return
}
- respondWithSession(w, user.ID, http.StatusOK)
+ respondWithSession(a.DB, w, user.ID, http.StatusOK)
}
diff --git a/pkg/server/handlers/auth_test.go b/pkg/server/handlers/auth_test.go
index 12cb10fe..4e563d4f 100644
--- a/pkg/server/handlers/auth_test.go
+++ b/pkg/server/handlers/auth_test.go
@@ -31,8 +31,8 @@ import (
)
func TestGetMe(t *testing.T) {
+ testutils.InitTestDB()
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
@@ -41,7 +41,7 @@ func TestGetMe(t *testing.T) {
defer server.Close()
u := testutils.SetupUserData()
- testutils.SetupAccountData(u, "alice@example.com", "somepassword")
+ testutils.SetupAccountData( u, "alice@example.com", "somepassword")
dat := `{"email": "alice@example.com"}`
req := testutils.MakeReq(server, "POST", "/reset-token", dat)
@@ -53,23 +53,24 @@ func TestGetMe(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "Status code mismtach")
var user database.User
- testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", u.ID).First(&user), "finding user")
assert.Equal(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch")
}
func TestCreateResetToken(t *testing.T) {
t.Run("success", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
u := testutils.SetupUserData()
- testutils.SetupAccountData(u, "alice@example.com", "somepassword")
+ testutils.SetupAccountData( u, "alice@example.com", "somepassword")
dat := `{"email": "alice@example.com"}`
req := testutils.MakeReq(server, "POST", "/reset-token", dat)
@@ -81,10 +82,10 @@ func TestCreateResetToken(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "Status code mismtach")
var tokenCount int
- testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting tokens")
+ testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting tokens")
var resetToken database.Token
- testutils.MustExec(t, db.Where("user_id = ? AND type = ?", u.ID, database.TokenTypeResetPassword).First(&resetToken), "finding reset token")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", u.ID, database.TokenTypeResetPassword).First(&resetToken), "finding reset token")
assert.Equal(t, tokenCount, 1, "reset_token count mismatch")
assert.NotEqual(t, resetToken.Value, nil, "reset_token value mismatch")
@@ -92,17 +93,18 @@ func TestCreateResetToken(t *testing.T) {
})
t.Run("nonexistent email", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
u := testutils.SetupUserData()
- testutils.SetupAccountData(u, "alice@example.com", "somepassword")
+ testutils.SetupAccountData( u, "alice@example.com", "somepassword")
dat := `{"email": "bob@example.com"}`
req := testutils.MakeReq(server, "POST", "/reset-token", dat)
@@ -114,36 +116,37 @@ func TestCreateResetToken(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "Status code mismtach")
var tokenCount int
- testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting tokens")
+ testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting tokens")
assert.Equal(t, tokenCount, 0, "reset_token count mismatch")
})
}
func TestResetPassword(t *testing.T) {
t.Run("success", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
u := testutils.SetupUserData()
- a := testutils.SetupAccountData(u, "alice@example.com", "oldpassword")
+ a := testutils.SetupAccountData( u, "alice@example.com", "oldpassword")
tok := database.Token{
UserID: u.ID,
Value: "MivFxYiSMMA4An9dP24DNQ==",
Type: database.TokenTypeResetPassword,
}
- testutils.MustExec(t, db.Save(&tok), "preparing token")
+ testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
otherTok := database.Token{
UserID: u.ID,
Value: "somerandomvalue",
Type: database.TokenTypeEmailVerification,
}
- testutils.MustExec(t, db.Save(&otherTok), "preparing another token")
+ testutils.MustExec(t, testutils.DB.Save(&otherTok), "preparing another token")
dat := `{"token": "MivFxYiSMMA4An9dP24DNQ==", "password": "newpassword"}`
req := testutils.MakeReq(server, "PATCH", "/reset-password", dat)
@@ -156,9 +159,9 @@ func TestResetPassword(t *testing.T) {
var resetToken, verificationToken database.Token
var account database.Account
- testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
- testutils.MustExec(t, db.Where("value = ?", "somerandomvalue").First(&verificationToken), "finding reset token")
- testutils.MustExec(t, db.Where("id = ?", a.ID).First(&account), "finding account")
+ testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
+ testutils.MustExec(t, testutils.DB.Where("value = ?", "somerandomvalue").First(&verificationToken), "finding reset token")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "finding account")
assert.NotEqual(t, resetToken.UsedAt, nil, "reset_token UsedAt mismatch")
passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte("newpassword"))
@@ -167,23 +170,24 @@ func TestResetPassword(t *testing.T) {
})
t.Run("nonexistent token", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
u := testutils.SetupUserData()
- a := testutils.SetupAccountData(u, "alice@example.com", "somepassword")
+ a := testutils.SetupAccountData( u, "alice@example.com", "somepassword")
tok := database.Token{
UserID: u.ID,
Value: "MivFxYiSMMA4An9dP24DNQ==",
Type: database.TokenTypeResetPassword,
}
- testutils.MustExec(t, db.Save(&tok), "preparing token")
+ testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
dat := `{"token": "-ApMnyvpg59uOU5b-Kf5uQ==", "password": "oldpassword"}`
req := testutils.MakeReq(server, "PATCH", "/reset-password", dat)
@@ -196,8 +200,8 @@ func TestResetPassword(t *testing.T) {
var resetToken database.Token
var account database.Account
- testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
- testutils.MustExec(t, db.Where("id = ?", a.ID).First(&account), "finding account")
+ testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "finding account")
assert.Equal(t, a.Password, account.Password, "password should not have been updated")
assert.Equal(t, a.Password, account.Password, "password should not have been updated")
@@ -205,24 +209,25 @@ func TestResetPassword(t *testing.T) {
})
t.Run("expired token", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
u := testutils.SetupUserData()
- a := testutils.SetupAccountData(u, "alice@example.com", "somepassword")
+ a := testutils.SetupAccountData( u, "alice@example.com", "somepassword")
tok := database.Token{
UserID: u.ID,
Value: "MivFxYiSMMA4An9dP24DNQ==",
Type: database.TokenTypeResetPassword,
}
- testutils.MustExec(t, db.Save(&tok), "preparing token")
- testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
+ testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
+ testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
dat := `{"token": "MivFxYiSMMA4An9dP24DNQ==", "password": "oldpassword"}`
req := testutils.MakeReq(server, "PATCH", "/reset-password", dat)
@@ -235,24 +240,25 @@ func TestResetPassword(t *testing.T) {
var resetToken database.Token
var account database.Account
- testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
- testutils.MustExec(t, db.Where("id = ?", a.ID).First(&account), "failed to find account")
+ testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "failed to find account")
assert.Equal(t, a.Password, account.Password, "password should not have been updated")
assert.Equal(t, resetToken.UsedAt, (*time.Time)(nil), "used_at should be nil")
})
t.Run("used token", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
u := testutils.SetupUserData()
- a := testutils.SetupAccountData(u, "alice@example.com", "somepassword")
+ a := testutils.SetupAccountData( u, "alice@example.com", "somepassword")
usedAt := time.Now().Add(time.Hour * -11).UTC()
tok := database.Token{
@@ -261,8 +267,8 @@ func TestResetPassword(t *testing.T) {
Type: database.TokenTypeResetPassword,
UsedAt: &usedAt,
}
- testutils.MustExec(t, db.Save(&tok), "preparing token")
- testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
+ testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
+ testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
dat := `{"token": "MivFxYiSMMA4An9dP24DNQ==", "password": "oldpassword"}`
req := testutils.MakeReq(server, "PATCH", "/reset-password", dat)
@@ -275,8 +281,8 @@ func TestResetPassword(t *testing.T) {
var resetToken database.Token
var account database.Account
- testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
- testutils.MustExec(t, db.Where("id = ?", a.ID).First(&account), "failed to find account")
+ testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "failed to find account")
assert.Equal(t, a.Password, account.Password, "password should not have been updated")
if resetToken.UsedAt.Year() != usedAt.Year() ||
@@ -290,24 +296,25 @@ func TestResetPassword(t *testing.T) {
})
t.Run("using wrong type token: email_verification", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
u := testutils.SetupUserData()
- a := testutils.SetupAccountData(u, "alice@example.com", "somepassword")
+ a := testutils.SetupAccountData( u, "alice@example.com", "somepassword")
tok := database.Token{
UserID: u.ID,
Value: "MivFxYiSMMA4An9dP24DNQ==",
Type: database.TokenTypeEmailVerification,
}
- testutils.MustExec(t, db.Save(&tok), "Failed to prepare reset_token")
- testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
+ testutils.MustExec(t, testutils.DB.Save(&tok), "Failed to prepare reset_token")
+ testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
dat := `{"token": "MivFxYiSMMA4An9dP24DNQ==", "password": "oldpassword"}`
req := testutils.MakeReq(server, "PATCH", "/reset-password", dat)
@@ -320,8 +327,8 @@ func TestResetPassword(t *testing.T) {
var resetToken database.Token
var account database.Account
- testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
- testutils.MustExec(t, db.Where("id = ?", a.ID).First(&account), "failed to find account")
+ testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "failed to find account")
assert.Equal(t, a.Password, account.Password, "password should not have been updated")
assert.Equal(t, resetToken.UsedAt, (*time.Time)(nil), "used_at should be nil")
diff --git a/pkg/server/handlers/classic.go b/pkg/server/handlers/classic.go
index 25aaaaa7..0f1b0c98 100644
--- a/pkg/server/handlers/classic.go
+++ b/pkg/server/handlers/classic.go
@@ -23,11 +23,11 @@ import (
"net/http"
"github.com/dnote/dnote/pkg/server/crypt"
+ "github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
+ "github.com/dnote/dnote/pkg/server/log"
"github.com/dnote/dnote/pkg/server/operations"
"github.com/dnote/dnote/pkg/server/presenters"
- "github.com/dnote/dnote/pkg/server/database"
- "github.com/dnote/dnote/pkg/server/log"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
)
@@ -39,15 +39,13 @@ func (a *App) classicMigrate(w http.ResponseWriter, r *http.Request) {
return
}
- db := database.DBConn
-
var account database.Account
- if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
+ if err := a.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
HandleError(w, "finding account", err, http.StatusInternalServerError)
return
}
- if err := db.Model(&account).
+ if err := a.DB.Model(&account).
Update(map[string]interface{}{
"salt": "",
"auth_key_hash": "",
@@ -66,8 +64,6 @@ type PresigninResponse struct {
}
func (a *App) classicPresignin(w http.ResponseWriter, r *http.Request) {
- db := database.DBConn
-
q := r.URL.Query()
email := q.Get("email")
if email == "" {
@@ -76,7 +72,7 @@ func (a *App) classicPresignin(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
- conn := db.Where("email = ?", email).First(&account)
+ conn := a.DB.Where("email = ?", email).First(&account)
if !conn.RecordNotFound() && conn.Error != nil {
HandleError(w, "getting user", conn.Error, http.StatusInternalServerError)
return
@@ -106,8 +102,6 @@ type classicSigninPayload struct {
}
func (a *App) classicSignin(w http.ResponseWriter, r *http.Request) {
- db := database.DBConn
-
var params classicSigninPayload
if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil {
HandleError(w, "decoding payload", err, http.StatusInternalServerError)
@@ -120,7 +114,7 @@ func (a *App) classicSignin(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
- conn := db.Where("email = ?", params.Email).First(&account)
+ conn := a.DB.Where("email = ?", params.Email).First(&account)
if conn.RecordNotFound() {
http.Error(w, ErrLoginFailure.Error(), http.StatusUnauthorized)
return
@@ -138,7 +132,7 @@ func (a *App) classicSignin(w http.ResponseWriter, r *http.Request) {
return
}
- session, err := operations.CreateSession(db, account.UserID)
+ session, err := operations.CreateSession(a.DB, account.UserID)
if err != nil {
HandleError(w, "creating session", nil, http.StatusBadRequest)
return
@@ -169,10 +163,8 @@ func (a *App) classicGetMe(w http.ResponseWriter, r *http.Request) {
return
}
- db := database.DBConn
-
var account database.Account
- if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
+ if err := a.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
HandleError(w, "finding account", err, http.StatusInternalServerError)
return
}
@@ -229,8 +221,6 @@ func (a *App) classicSetPassword(w http.ResponseWriter, r *http.Request) {
return
}
- db := database.DBConn
-
var params classicSetPasswordPayload
if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil {
HandleError(w, "decoding payload", err, http.StatusInternalServerError)
@@ -238,7 +228,7 @@ func (a *App) classicSetPassword(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
- if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
+ if err := a.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
HandleError(w, "getting user", nil, http.StatusInternalServerError)
return
}
@@ -249,7 +239,7 @@ func (a *App) classicSetPassword(w http.ResponseWriter, r *http.Request) {
return
}
- if err := db.Model(&account).Update("password", string(hashedNewPassword)).Error; err != nil {
+ if err := a.DB.Model(&account).Update("password", string(hashedNewPassword)).Error; err != nil {
http.Error(w, errors.Wrap(err, "updating password").Error(), http.StatusInternalServerError)
return
}
@@ -265,8 +255,7 @@ func (a *App) classicGetNotes(w http.ResponseWriter, r *http.Request) {
}
var notes []database.Note
- db := database.DBConn
- if err := db.Where("user_id = ? AND encrypted = true", user.ID).Find(¬es).Error; err != nil {
+ if err := a.DB.Where("user_id = ? AND encrypted = true", user.ID).Find(¬es).Error; err != nil {
HandleError(w, "finding notes", err, http.StatusInternalServerError)
return
}
diff --git a/pkg/server/handlers/classic_test.go b/pkg/server/handlers/classic_test.go
index 004a1edd..0a01fc59 100644
--- a/pkg/server/handlers/classic_test.go
+++ b/pkg/server/handlers/classic_test.go
@@ -22,26 +22,16 @@ import (
"encoding/json"
"fmt"
"net/http"
- "os"
"testing"
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/database"
- "github.com/dnote/dnote/pkg/server/mailer"
"github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
)
-func init() {
- testutils.InitTestDB()
-
- templatePath := os.Getenv("DNOTE_TEST_EMAIL_TEMPLATE_DIR")
- mailer.InitTemplates(&templatePath)
-}
-
func TestClassicPresignin(t *testing.T) {
- db := database.DBConn
defer testutils.ClearData()
alice := database.Account{
@@ -52,8 +42,8 @@ func TestClassicPresignin(t *testing.T) {
Email: database.ToNullString("bob@example.com"),
ClientKDFIteration: 200000,
}
- testutils.MustExec(t, db.Save(&alice), "saving alice")
- testutils.MustExec(t, db.Save(&bob), "saving bob")
+ testutils.MustExec(t, testutils.DB.Save(&alice), "saving alice")
+ testutils.MustExec(t, testutils.DB.Save(&bob), "saving bob")
testCases := []struct {
email string
@@ -121,12 +111,11 @@ func TestClassicPresignin_MissingParams(t *testing.T) {
}
func TestClassicSignin(t *testing.T) {
- db := database.DBConn
defer testutils.ClearData()
user := testutils.SetupUserData()
alice := testutils.SetupClassicAccountData(user, "alice@example.com")
- testutils.MustExec(t, db.Save(&alice), "saving alice")
+ testutils.MustExec(t, testutils.DB.Save(&alice), "saving alice")
// Setup
server := MustNewServer(t, &App{
@@ -145,8 +134,8 @@ func TestClassicSignin(t *testing.T) {
var sessionCount int
var session database.Session
- testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
- testutils.MustExec(t, db.First(&session), "getting session")
+ testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
+ testutils.MustExec(t, testutils.DB.First(&session), "getting session")
var got SessionResponse
if err := json.NewDecoder(res.Body).Decode(&got); err != nil {
@@ -165,7 +154,7 @@ func TestClassicSignin(t *testing.T) {
}
func TestClassicSignin_Failure(t *testing.T) {
- db := database.DBConn
+
defer testutils.ClearData()
//password: correctbattery
@@ -183,8 +172,8 @@ func TestClassicSignin_Failure(t *testing.T) {
// plain authKey: DN4d/teaq1I2bVYZ7QWaah4Fu7q2y2N4yJNZk76hFHw=
AuthKeyHash: "fGOMHHAw9G7CH4Gv2EM1ZcZZklC1a55fS3QJ0qQVp4k=",
}
- testutils.MustExec(t, db.Save(&alice), "saving alice")
- testutils.MustExec(t, db.Save(&bob), "saving bob")
+ testutils.MustExec(t, testutils.DB.Save(&alice), "saving alice")
+ testutils.MustExec(t, testutils.DB.Save(&bob), "saving bob")
testCases := []struct {
email string
diff --git a/pkg/server/handlers/health_test.go b/pkg/server/handlers/health_test.go
index 6a574f6d..78067e82 100644
--- a/pkg/server/handlers/health_test.go
+++ b/pkg/server/handlers/health_test.go
@@ -25,13 +25,13 @@ import (
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/testutils"
+ "github.com/jinzhu/gorm"
)
func TestCheckHealth(t *testing.T) {
- defer testutils.ClearData()
-
// Setup
server := MustNewServer(t, &App{
+ DB: &gorm.DB{},
Clock: clock.NewMock(),
})
defer server.Close()
diff --git a/pkg/server/handlers/main_test.go b/pkg/server/handlers/main_test.go
new file mode 100644
index 00000000..8e7028d0
--- /dev/null
+++ b/pkg/server/handlers/main_test.go
@@ -0,0 +1,20 @@
+package handlers
+
+import (
+ "os"
+ "testing"
+
+ "github.com/dnote/dnote/pkg/server/mailer"
+ "github.com/dnote/dnote/pkg/server/testutils"
+)
+
+func TestMain(m *testing.M) {
+ testutils.InitTestDB()
+ templatePath := os.Getenv("DNOTE_TEST_EMAIL_TEMPLATE_DIR")
+ mailer.InitTemplates(&templatePath)
+
+ code := m.Run()
+ testutils.ClearData()
+
+ os.Exit(code)
+}
diff --git a/pkg/server/handlers/notes.go b/pkg/server/handlers/notes.go
index c8ac1335..a850a82d 100644
--- a/pkg/server/handlers/notes.go
+++ b/pkg/server/handlers/notes.go
@@ -86,9 +86,7 @@ func parseSearchQuery(q url.Values) string {
return escapeSearchQuery(searchStr)
}
-func getNoteBaseQuery(noteUUID string, search string) *gorm.DB {
- db := database.DBConn
-
+func getNoteBaseQuery(db *gorm.DB, noteUUID string, search string) *gorm.DB {
var conn *gorm.DB
if search != "" {
conn = selectFTSFields(db, search, &ftsParams{HighlightAll: true})
@@ -102,7 +100,7 @@ func getNoteBaseQuery(noteUUID string, search string) *gorm.DB {
}
func (a *App) getNote(w http.ResponseWriter, r *http.Request) {
- user, _, err := AuthWithSession(r, nil)
+ user, _, err := AuthWithSession(a.DB, r, nil)
if err != nil {
HandleError(w, "authenticating", err, http.StatusInternalServerError)
return
@@ -111,7 +109,7 @@ func (a *App) getNote(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
noteUUID := vars["noteUUID"]
- note, ok, err := operations.GetNote(noteUUID, user)
+ note, ok, err := operations.GetNote(a.DB, noteUUID, user)
if !ok {
RespondNotFound(w)
return
@@ -145,17 +143,17 @@ func (a *App) getNotes(w http.ResponseWriter, r *http.Request) {
}
query := r.URL.Query()
- respondGetNotes(user.ID, query, w)
+ respondGetNotes(a.DB, user.ID, query, w)
}
-func respondGetNotes(userID int, query url.Values, w http.ResponseWriter) {
+func respondGetNotes(db *gorm.DB, userID int, query url.Values, w http.ResponseWriter) {
q, err := parseGetNotesQuery(query)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
- conn := getNotesBaseQuery(userID, q)
+ conn := getNotesBaseQuery(db, userID, q)
var total int
if err := conn.Model(database.Note{}).Count(&total).Error; err != nil {
@@ -274,9 +272,7 @@ func getDateBounds(year, month int) (int64, int64) {
return lower, upper
}
-func getNotesBaseQuery(userID int, q getNotesQuery) *gorm.DB {
- db := database.DBConn
-
+func getNotesBaseQuery(db *gorm.DB, userID int, q getNotesQuery) *gorm.DB {
conn := db.Where(
"notes.user_id = ? AND notes.deleted = ? AND notes.encrypted = ?",
userID, false, q.Encrypted,
@@ -317,8 +313,7 @@ func (a *App) legacyGetNotes(w http.ResponseWriter, r *http.Request) {
}
var notes []database.Note
- db := database.DBConn
- if err := db.Where("user_id = ? AND encrypted = true", user.ID).Find(¬es).Error; err != nil {
+ if err := a.DB.Where("user_id = ? AND encrypted = true", user.ID).Find(¬es).Error; err != nil {
HandleError(w, "finding notes", err, http.StatusInternalServerError)
return
}
diff --git a/pkg/server/handlers/notes_test.go b/pkg/server/handlers/notes_test.go
index ee9cc6ea..22ec8918 100644
--- a/pkg/server/handlers/notes_test.go
+++ b/pkg/server/handlers/notes_test.go
@@ -28,16 +28,12 @@ import (
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
- "github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/database"
+ "github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
)
-func init() {
- testutils.InitTestDB()
-}
-
func getExpectedNotePayload(n database.Note, b database.Book, u database.User) presenters.Note {
return presenters.Note{
UUID: n.UUID,
@@ -59,11 +55,12 @@ func getExpectedNotePayload(n database.Note, b database.Book, u database.User) p
}
func TestGetNotes(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -75,17 +72,17 @@ func TestGetNotes(t *testing.T) {
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, db.Save(&b1), "preparing b1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
b2 := database.Book{
UserID: user.ID,
Label: "css",
}
- testutils.MustExec(t, db.Save(&b2), "preparing b2")
+ testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
b3 := database.Book{
UserID: anotherUser.ID,
Label: "css",
}
- testutils.MustExec(t, db.Save(&b3), "preparing b3")
+ testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
n1 := database.Note{
UserID: user.ID,
@@ -95,7 +92,7 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
- testutils.MustExec(t, db.Save(&n1), "preparing n1")
+ testutils.MustExec(t, testutils.DB.Save(&n1), "preparing n1")
n2 := database.Note{
UserID: user.ID,
BookUUID: b1.UUID,
@@ -104,7 +101,7 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2018, time.August, 11, 22, 0, 0, 0, time.UTC).UnixNano(),
}
- testutils.MustExec(t, db.Save(&n2), "preparing n2")
+ testutils.MustExec(t, testutils.DB.Save(&n2), "preparing n2")
n3 := database.Note{
UserID: user.ID,
BookUUID: b1.UUID,
@@ -113,7 +110,7 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2017, time.January, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
- testutils.MustExec(t, db.Save(&n3), "preparing n3")
+ testutils.MustExec(t, testutils.DB.Save(&n3), "preparing n3")
n4 := database.Note{
UserID: user.ID,
BookUUID: b2.UUID,
@@ -122,7 +119,7 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2018, time.September, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
- testutils.MustExec(t, db.Save(&n4), "preparing n4")
+ testutils.MustExec(t, testutils.DB.Save(&n4), "preparing n4")
n5 := database.Note{
UserID: anotherUser.ID,
BookUUID: b3.UUID,
@@ -131,7 +128,7 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
- testutils.MustExec(t, db.Save(&n5), "preparing n5")
+ testutils.MustExec(t, testutils.DB.Save(&n5), "preparing n5")
n6 := database.Note{
UserID: user.ID,
BookUUID: b1.UUID,
@@ -140,7 +137,7 @@ func TestGetNotes(t *testing.T) {
Deleted: true,
AddedOn: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
- testutils.MustExec(t, db.Save(&n6), "preparing n6")
+ testutils.MustExec(t, testutils.DB.Save(&n6), "preparing n6")
// Execute
req := testutils.MakeReq(server, "GET", "/notes?year=2018&month=8", "")
@@ -155,8 +152,8 @@ func TestGetNotes(t *testing.T) {
}
var n2Record, n1Record database.Note
- testutils.MustExec(t, db.Where("uuid = ?", n2.UUID).First(&n2Record), "finding n2Record")
- testutils.MustExec(t, db.Where("uuid = ?", n1.UUID).First(&n1Record), "finding n1Record")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", n2.UUID).First(&n2Record), "finding n2Record")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", n1.UUID).First(&n1Record), "finding n1Record")
expected := GetNotesResponse{
Notes: []presenters.Note{
@@ -170,11 +167,12 @@ func TestGetNotes(t *testing.T) {
}
func TestGetNote(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -186,7 +184,7 @@ func TestGetNote(t *testing.T) {
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, db.Save(&b1), "preparing b1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
privateNote := database.Note{
UserID: user.ID,
@@ -194,20 +192,20 @@ func TestGetNote(t *testing.T) {
Body: "privateNote content",
Public: false,
}
- testutils.MustExec(t, db.Save(&privateNote), "preparing privateNote")
+ testutils.MustExec(t, testutils.DB.Save(&privateNote), "preparing privateNote")
publicNote := database.Note{
UserID: user.ID,
BookUUID: b1.UUID,
Body: "publicNote content",
Public: true,
}
- testutils.MustExec(t, db.Save(&publicNote), "preparing publicNote")
+ testutils.MustExec(t, testutils.DB.Save(&publicNote), "preparing publicNote")
deletedNote := database.Note{
UserID: user.ID,
BookUUID: b1.UUID,
Deleted: true,
}
- testutils.MustExec(t, db.Save(&deletedNote), "preparing publicNote")
+ testutils.MustExec(t, testutils.DB.Save(&deletedNote), "preparing publicNote")
t.Run("owner accessing private note", func(t *testing.T) {
// Execute
@@ -224,7 +222,7 @@ func TestGetNote(t *testing.T) {
}
var n1Record database.Note
- testutils.MustExec(t, db.Where("uuid = ?", privateNote.UUID).First(&n1Record), "finding n1Record")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", privateNote.UUID).First(&n1Record), "finding n1Record")
expected := getExpectedNotePayload(n1Record, b1, user)
assert.DeepEqual(t, payload, expected, "payload mismatch")
@@ -245,7 +243,7 @@ func TestGetNote(t *testing.T) {
}
var n2Record database.Note
- testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
expected := getExpectedNotePayload(n2Record, b1, user)
assert.DeepEqual(t, payload, expected, "payload mismatch")
@@ -266,7 +264,7 @@ func TestGetNote(t *testing.T) {
}
var n2Record database.Note
- testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
expected := getExpectedNotePayload(n2Record, b1, user)
assert.DeepEqual(t, payload, expected, "payload mismatch")
@@ -304,7 +302,7 @@ func TestGetNote(t *testing.T) {
}
var n2Record database.Note
- testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
expected := getExpectedNotePayload(n2Record, b1, user)
assert.DeepEqual(t, payload, expected, "payload mismatch")
diff --git a/pkg/server/handlers/repetition_rules.go b/pkg/server/handlers/repetition_rules.go
index cc305cb1..f363f557 100644
--- a/pkg/server/handlers/repetition_rules.go
+++ b/pkg/server/handlers/repetition_rules.go
@@ -23,9 +23,9 @@ import (
"net/http"
"time"
+ "github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/presenters"
- "github.com/dnote/dnote/pkg/server/database"
"github.com/gorilla/mux"
"github.com/pkg/errors"
)
@@ -45,9 +45,8 @@ func (a *App) getRepetitionRule(w http.ResponseWriter, r *http.Request) {
return
}
- db := database.DBConn
var repetitionRule database.RepetitionRule
- if err := db.Where("user_id = ? AND uuid = ?", user.ID, repetitionRuleUUID).Preload("Books").Find(&repetitionRule).Error; err != nil {
+ if err := a.DB.Where("user_id = ? AND uuid = ?", user.ID, repetitionRuleUUID).Preload("Books").Find(&repetitionRule).Error; err != nil {
HandleError(w, "getting repetition rules", err, http.StatusInternalServerError)
return
}
@@ -63,9 +62,8 @@ func (a *App) getRepetitionRules(w http.ResponseWriter, r *http.Request) {
return
}
- db := database.DBConn
var repetitionRules []database.RepetitionRule
- if err := db.Where("user_id = ?", user.ID).Preload("Books").Order("last_active DESC").Find(&repetitionRules).Error; err != nil {
+ if err := a.DB.Where("user_id = ?", user.ID).Preload("Books").Order("last_active DESC").Find(&repetitionRules).Error; err != nil {
HandleError(w, "getting repetition rules", err, http.StatusInternalServerError)
return
}
@@ -288,9 +286,8 @@ func (a *App) createRepetitionRule(w http.ResponseWriter, r *http.Request) {
return
}
- db := database.DBConn
var books []database.Book
- if err := db.Where("user_id = ? AND uuid IN (?)", user.ID, params.GetBookUUIDs()).Find(&books).Error; err != nil {
+ if err := a.DB.Where("user_id = ? AND uuid IN (?)", user.ID, params.GetBookUUIDs()).Find(&books).Error; err != nil {
HandleError(w, "finding books", nil, http.StatusInternalServerError)
return
}
@@ -313,7 +310,7 @@ func (a *App) createRepetitionRule(w http.ResponseWriter, r *http.Request) {
NoteCount: params.GetNoteCount(),
Enabled: params.GetEnabled(),
}
- if err := db.Create(&record).Error; err != nil {
+ if err := a.DB.Create(&record).Error; err != nil {
HandleError(w, "creating a repetition rule", err, http.StatusInternalServerError)
return
}
@@ -346,10 +343,8 @@ func (a *App) deleteRepetitionRule(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
repetitionRuleUUID := vars["repetitionRuleUUID"]
- db := database.DBConn
-
var rule database.RepetitionRule
- conn := db.Where("uuid = ? AND user_id = ?", repetitionRuleUUID, user.ID).First(&rule)
+ conn := a.DB.Where("uuid = ? AND user_id = ?", repetitionRuleUUID, user.ID).First(&rule)
if conn.RecordNotFound() {
http.Error(w, "Not found", http.StatusNotFound)
@@ -359,7 +354,7 @@ func (a *App) deleteRepetitionRule(w http.ResponseWriter, r *http.Request) {
return
}
- if err := db.Exec("DELETE from repetition_rules WHERE uuid = ?", rule.UUID).Error; err != nil {
+ if err := a.DB.Exec("DELETE from repetition_rules WHERE uuid = ?", rule.UUID).Error; err != nil {
HandleError(w, "deleting the repetition rule", err, http.StatusInternalServerError)
}
@@ -382,8 +377,7 @@ func (a *App) updateRepetitionRule(w http.ResponseWriter, r *http.Request) {
return
}
- db := database.DBConn
- tx := db.Begin()
+ tx := a.DB.Begin()
var repetitionRule database.RepetitionRule
if err := tx.Where("user_id = ? AND uuid = ?", user.ID, repetitionRuleUUID).Preload("Books").First(&repetitionRule).Error; err != nil {
diff --git a/pkg/server/handlers/repetition_rules_test.go b/pkg/server/handlers/repetition_rules_test.go
index 3015d41c..97c9eadd 100644
--- a/pkg/server/handlers/repetition_rules_test.go
+++ b/pkg/server/handlers/repetition_rules_test.go
@@ -27,22 +27,19 @@ import (
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
- "github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/database"
+ "github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
)
-func init() {
- testutils.InitTestDB()
-}
-
func TestGetRepetitionRule(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -53,7 +50,7 @@ func TestGetRepetitionRule(t *testing.T) {
USN: 11,
Label: "js",
}
- testutils.MustExec(t, db.Save(&b1), "preparing book1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing book1")
r1 := database.RepetitionRule{
Title: "Rule 1",
@@ -66,7 +63,7 @@ func TestGetRepetitionRule(t *testing.T) {
Books: []database.Book{b1},
NoteCount: 5,
}
- testutils.MustExec(t, db.Save(&r1), "preparing rule1")
+ testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
// Execute
req := testutils.MakeReq(server, "GET", fmt.Sprintf("/repetition_rules/%s", r1.UUID), "")
@@ -81,9 +78,9 @@ func TestGetRepetitionRule(t *testing.T) {
}
var r1Record database.RepetitionRule
- testutils.MustExec(t, db.Where("uuid = ?", r1.UUID).First(&r1Record), "finding r1Record")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", r1.UUID).First(&r1Record), "finding r1Record")
var b1Record database.Book
- testutils.MustExec(t, db.Where("uuid = ?", b1.UUID).First(&b1Record), "finding b1Record")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", b1.UUID).First(&b1Record), "finding b1Record")
expected := presenters.RepetitionRule{
UUID: r1Record.UUID,
@@ -112,11 +109,12 @@ func TestGetRepetitionRule(t *testing.T) {
}
func TestGetRepetitionRules(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -127,7 +125,7 @@ func TestGetRepetitionRules(t *testing.T) {
USN: 11,
Label: "js",
}
- testutils.MustExec(t, db.Save(&b1), "preparing book1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing book1")
r1 := database.RepetitionRule{
Title: "Rule 1",
@@ -140,7 +138,7 @@ func TestGetRepetitionRules(t *testing.T) {
Books: []database.Book{b1},
NoteCount: 5,
}
- testutils.MustExec(t, db.Save(&r1), "preparing rule1")
+ testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
r2 := database.RepetitionRule{
Title: "Rule 2",
@@ -153,7 +151,7 @@ func TestGetRepetitionRules(t *testing.T) {
Books: []database.Book{},
NoteCount: 5,
}
- testutils.MustExec(t, db.Save(&r2), "preparing rule2")
+ testutils.MustExec(t, testutils.DB.Save(&r2), "preparing rule2")
// Execute
req := testutils.MakeReq(server, "GET", "/repetition_rules", "")
@@ -168,10 +166,10 @@ func TestGetRepetitionRules(t *testing.T) {
}
var r1Record, r2Record database.RepetitionRule
- testutils.MustExec(t, db.Where("uuid = ?", r1.UUID).First(&r1Record), "finding r1Record")
- testutils.MustExec(t, db.Where("uuid = ?", r2.UUID).First(&r2Record), "finding r2Record")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", r1.UUID).First(&r1Record), "finding r1Record")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", r2.UUID).First(&r2Record), "finding r2Record")
var b1Record database.Book
- testutils.MustExec(t, db.Where("uuid = ?", b1.UUID).First(&b1Record), "finding b1Record")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", b1.UUID).First(&b1Record), "finding b1Record")
expected := []presenters.RepetitionRule{
{
@@ -217,8 +215,8 @@ func TestGetRepetitionRules(t *testing.T) {
func TestCreateRepetitionRules(t *testing.T) {
t.Run("all books", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
c := clock.NewMock()
@@ -226,6 +224,7 @@ func TestCreateRepetitionRules(t *testing.T) {
c.SetNow(t0)
server := MustNewServer(t, &App{
+
Clock: c,
})
defer server.Close()
@@ -250,11 +249,11 @@ func TestCreateRepetitionRules(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusCreated, "")
var ruleCount int
- testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
+ testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
assert.Equalf(t, ruleCount, 1, "reperition rule count mismatch")
var rule database.RepetitionRule
- testutils.MustExec(t, db.Preload("Books").First(&rule), "finding b1Record")
+ testutils.MustExec(t, testutils.DB.Preload("Books").First(&rule), "finding b1Record")
assert.NotEqual(t, rule.UUID, "", "rule UUID mismatch")
assert.Equal(t, rule.Title, "Rule 1", "rule Title mismatch")
@@ -275,8 +274,8 @@ func TestCreateRepetitionRules(t *testing.T) {
}
for _, tc := range bookDomainTestCases {
t.Run(tc, func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
c := clock.NewMock()
@@ -284,6 +283,7 @@ func TestCreateRepetitionRules(t *testing.T) {
c.SetNow(t0)
server := MustNewServer(t, &App{
+
Clock: c,
})
defer server.Close()
@@ -294,7 +294,7 @@ func TestCreateRepetitionRules(t *testing.T) {
UserID: user.ID,
Label: "css",
}
- testutils.MustExec(t, db.Save(&b1), "preparing b1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
// Execute
dat := fmt.Sprintf(`{
@@ -314,14 +314,14 @@ func TestCreateRepetitionRules(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusCreated, "")
var ruleCount int
- testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
+ testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
assert.Equalf(t, ruleCount, 1, "reperition rule count mismatch")
var rule database.RepetitionRule
- testutils.MustExec(t, db.Preload("Books").First(&rule), "finding b1Record")
+ testutils.MustExec(t, testutils.DB.Preload("Books").First(&rule), "finding b1Record")
var b1Record database.Book
- testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1Record")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1Record")
assert.NotEqual(t, rule.UUID, "", "rule UUID mismatch")
assert.Equal(t, rule.Title, "Rule 1", "rule Title mismatch")
@@ -339,14 +339,15 @@ func TestCreateRepetitionRules(t *testing.T) {
}
func TestUpdateRepetitionRules(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
c := clock.NewMock()
t0 := time.Date(2009, time.November, 1, 2, 3, 4, 5, time.UTC)
c.SetNow(t0)
server := MustNewServer(t, &App{
+
Clock: c,
})
defer server.Close()
@@ -367,13 +368,13 @@ func TestUpdateRepetitionRules(t *testing.T) {
Books: []database.Book{},
NoteCount: 20,
}
- testutils.MustExec(t, db.Save(&r1), "preparing r1")
+ testutils.MustExec(t, testutils.DB.Save(&r1), "preparing r1")
b1 := database.Book{
UserID: user.ID,
USN: 11,
Label: "js",
}
- testutils.MustExec(t, db.Save(&b1), "preparing book1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing book1")
dat := fmt.Sprintf(`{
"title": "Rule 1 - edited",
@@ -393,14 +394,14 @@ func TestUpdateRepetitionRules(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "")
var totalRuleCount int
- testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&totalRuleCount), "counting rules")
+ testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&totalRuleCount), "counting rules")
assert.Equalf(t, totalRuleCount, 1, "reperition rule count mismatch")
var rule database.RepetitionRule
- testutils.MustExec(t, db.Preload("Books").First(&rule), "finding b1Record")
+ testutils.MustExec(t, testutils.DB.Preload("Books").First(&rule), "finding b1Record")
var b1Record database.Book
- testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1Record")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1Record")
assert.NotEqual(t, rule.UUID, "", "rule UUID mismatch")
assert.Equal(t, rule.Title, "Rule 1 - edited", "rule Title mismatch")
@@ -416,11 +417,12 @@ func TestUpdateRepetitionRules(t *testing.T) {
}
func TestDeleteRepetitionRules(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -439,7 +441,7 @@ func TestDeleteRepetitionRules(t *testing.T) {
Books: []database.Book{},
NoteCount: 20,
}
- testutils.MustExec(t, db.Save(&r1), "preparing r1")
+ testutils.MustExec(t, testutils.DB.Save(&r1), "preparing r1")
r2 := database.RepetitionRule{
Title: "Rule 1",
@@ -452,7 +454,7 @@ func TestDeleteRepetitionRules(t *testing.T) {
Books: []database.Book{},
NoteCount: 20,
}
- testutils.MustExec(t, db.Save(&r2), "preparing r2")
+ testutils.MustExec(t, testutils.DB.Save(&r2), "preparing r2")
endpoint := fmt.Sprintf("/repetition_rules/%s", r1.UUID)
req := testutils.MakeReq(server, "DELETE", endpoint, "")
@@ -462,11 +464,11 @@ func TestDeleteRepetitionRules(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "")
var totalRuleCount int
- testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&totalRuleCount), "counting rules")
+ testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&totalRuleCount), "counting rules")
assert.Equalf(t, totalRuleCount, 1, "reperition rule count mismatch")
var r2Count int
- testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Where("id = ?", r2.ID).Count(&r2Count), "counting r2")
+ testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Where("id = ?", r2.ID).Count(&r2Count), "counting r2")
assert.Equalf(t, r2Count, 1, "r2 count mismatch")
}
@@ -541,11 +543,12 @@ func TestCreateUpdateRepetitionRules_BadRequest(t *testing.T) {
for idx, tc := range testCases {
t.Run(fmt.Sprintf("test case - create %d", idx), func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -560,13 +563,13 @@ func TestCreateUpdateRepetitionRules_BadRequest(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "")
var ruleCount int
- testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
+ testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
assert.Equalf(t, ruleCount, 0, "reperition rule count mismatch")
})
t.Run(fmt.Sprintf("test case %d - update", idx), func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
user := testutils.SetupUserData()
@@ -581,15 +584,16 @@ func TestCreateUpdateRepetitionRules_BadRequest(t *testing.T) {
Books: []database.Book{},
NoteCount: 20,
}
- testutils.MustExec(t, db.Save(&r1), "preparing r1")
+ testutils.MustExec(t, testutils.DB.Save(&r1), "preparing r1")
b1 := database.Book{
UserID: user.ID,
USN: 11,
Label: "js",
}
- testutils.MustExec(t, db.Save(&b1), "preparing book1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing book1")
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -602,7 +606,7 @@ func TestCreateUpdateRepetitionRules_BadRequest(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "")
var ruleCount int
- testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
+ testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
assert.Equalf(t, ruleCount, 1, "reperition rule count mismatch")
})
}
@@ -624,11 +628,12 @@ func TestCreateRepetitionRules_BadRequest(t *testing.T) {
for idx, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -643,7 +648,7 @@ func TestCreateRepetitionRules_BadRequest(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "")
var ruleCount int
- testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
+ testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
assert.Equalf(t, ruleCount, 0, "reperition rule count mismatch")
})
}
diff --git a/pkg/server/handlers/routes.go b/pkg/server/handlers/routes.go
index 42a8b849..b5ac7a2e 100644
--- a/pkg/server/handlers/routes.go
+++ b/pkg/server/handlers/routes.go
@@ -31,6 +31,7 @@ import (
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/log"
"github.com/gorilla/mux"
+ "github.com/jinzhu/gorm"
"github.com/pkg/errors"
"github.com/stripe/stripe-go"
)
@@ -63,28 +64,6 @@ func parseAuthHeader(h string) (authHeader, error) {
return parsed, nil
}
-func legacyAuth(next http.HandlerFunc) http.HandlerFunc {
- db := database.DBConn
-
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- c, err := r.Cookie("api_key")
- if err != nil {
- http.Error(w, "Invalid API key", http.StatusUnauthorized)
- return
- }
-
- apiKey := c.Value
- var user database.User
- if db.Where("api_key = ?", apiKey).First(&user).RecordNotFound() {
- http.Error(w, "Invalid API key", http.StatusUnauthorized)
- return
- }
-
- ctx := context.WithValue(r.Context(), helpers.KeyUser, user)
- next.ServeHTTP(w, r.WithContext(ctx))
- })
-}
-
// getSessionKeyFromCookie reads and returns a session key from the cookie sent by the
// request. If no session key is found, it returns an empty string
func getSessionKeyFromCookie(r *http.Request) (string, error) {
@@ -138,8 +117,7 @@ func getCredential(r *http.Request) (string, error) {
}
// AuthWithSession performs user authentication with session
-func AuthWithSession(r *http.Request, p *AuthMiddlewareParams) (database.User, bool, error) {
- db := database.DBConn
+func AuthWithSession(db *gorm.DB, r *http.Request, p *AuthMiddlewareParams) (database.User, bool, error) {
var user database.User
sessionKey, err := getCredential(r)
@@ -174,8 +152,7 @@ func AuthWithSession(r *http.Request, p *AuthMiddlewareParams) (database.User, b
return user, true, nil
}
-func authWithToken(r *http.Request, tokenType string, p *AuthMiddlewareParams) (database.User, database.Token, bool, error) {
- db := database.DBConn
+func authWithToken(db *gorm.DB, r *http.Request, tokenType string, p *AuthMiddlewareParams) (database.User, database.Token, bool, error) {
var user database.User
var token database.Token
@@ -208,9 +185,9 @@ type AuthMiddlewareParams struct {
ProOnly bool
}
-func auth(next http.HandlerFunc, p *AuthMiddlewareParams) http.HandlerFunc {
+func (a *App) auth(next http.HandlerFunc, p *AuthMiddlewareParams) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- user, ok, err := AuthWithSession(r, p)
+ user, ok, err := AuthWithSession(a.DB, r, p)
if !ok {
respondUnauthorized(w)
return
@@ -231,9 +208,9 @@ func auth(next http.HandlerFunc, p *AuthMiddlewareParams) http.HandlerFunc {
})
}
-func tokenAuth(next http.HandlerFunc, tokenType string, p *AuthMiddlewareParams) http.HandlerFunc {
+func (a *App) tokenAuth(next http.HandlerFunc, tokenType string, p *AuthMiddlewareParams) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- user, token, ok, err := authWithToken(r, tokenType, p)
+ user, token, ok, err := authWithToken(a.DB, r, tokenType, p)
if err != nil {
// log the error and continue
log.ErrorWrap(err, "authenticating with token")
@@ -245,7 +222,7 @@ func tokenAuth(next http.HandlerFunc, tokenType string, p *AuthMiddlewareParams)
ctx = context.WithValue(ctx, helpers.KeyToken, token)
} else {
// If token-based auth fails, fall back to session-based auth
- user, ok, err = AuthWithSession(r, p)
+ user, ok, err = AuthWithSession(a.DB, r, p)
if err != nil {
HandleError(w, "authenticating with session", err, http.StatusInternalServerError)
return
@@ -325,6 +302,7 @@ func applyMiddleware(h http.HandlerFunc, rateLimit bool) http.Handler {
// App is an application configuration
type App struct {
+ DB *gorm.DB
Clock clock.Clock
StripeAPIBackend stripe.Backend
WebURL string
@@ -334,6 +312,9 @@ func (a *App) validate() error {
if a.WebURL == "" {
return errors.New("WebURL is empty")
}
+ if a.DB == nil {
+ return errors.New("DB is empty")
+ }
return nil
}
@@ -364,51 +345,50 @@ func NewRouter(app *App) (*mux.Router, error) {
var routes = []Route{
// internal
{"GET", "/health", app.checkHealth, false},
- {"GET", "/me", auth(app.getMe, nil), true},
- {"POST", "/verification-token", auth(app.createVerificationToken, nil), true},
+ {"GET", "/me", app.auth(app.getMe, nil), true},
+ {"POST", "/verification-token", app.auth(app.createVerificationToken, nil), true},
{"PATCH", "/verify-email", app.verifyEmail, true},
{"POST", "/reset-token", app.createResetToken, true},
{"PATCH", "/reset-password", app.resetPassword, true},
- {"PATCH", "/account/profile", auth(app.updateProfile, nil), true},
- {"PATCH", "/account/password", auth(app.updatePassword, nil), true},
- {"GET", "/account/email-preference", tokenAuth(app.getEmailPreference, database.TokenTypeEmailPreference, nil), true},
- {"PATCH", "/account/email-preference", tokenAuth(app.updateEmailPreference, database.TokenTypeEmailPreference, nil), true},
- {"POST", "/subscriptions", auth(app.createSub, nil), true},
- {"PATCH", "/subscriptions", auth(app.updateSub, nil), true},
+ {"PATCH", "/account/profile", app.auth(app.updateProfile, nil), true},
+ {"PATCH", "/account/password", app.auth(app.updatePassword, nil), true},
+ {"GET", "/account/email-preference", app.tokenAuth(app.getEmailPreference, database.TokenTypeEmailPreference, nil), true},
+ {"PATCH", "/account/email-preference", app.tokenAuth(app.updateEmailPreference, database.TokenTypeEmailPreference, nil), true},
+ {"POST", "/subscriptions", app.auth(app.createSub, nil), true},
+ {"PATCH", "/subscriptions", app.auth(app.updateSub, nil), true},
{"POST", "/webhooks/stripe", app.stripeWebhook, true},
- {"GET", "/subscriptions", auth(app.getSub, nil), true},
- {"GET", "/stripe_source", auth(app.getStripeSource, nil), true},
- {"PATCH", "/stripe_source", auth(app.updateStripeSource, nil), true},
- {"GET", "/notes", auth(app.getNotes, &proOnly), false},
+ {"GET", "/subscriptions", app.auth(app.getSub, nil), true},
+ {"GET", "/stripe_source", app.auth(app.getStripeSource, nil), true},
+ {"PATCH", "/stripe_source", app.auth(app.updateStripeSource, nil), true},
+ {"GET", "/notes", app.auth(app.getNotes, &proOnly), false},
{"GET", "/notes/{noteUUID}", app.getNote, true},
- {"GET", "/calendar", auth(app.getCalendar, &proOnly), true},
- {"GET", "/repetition_rules", auth(app.getRepetitionRules, &proOnly), true},
- {"GET", "/repetition_rules/{repetitionRuleUUID}", tokenAuth(app.getRepetitionRule, database.TokenTypeRepetition, &proOnly), true},
- {"POST", "/repetition_rules", auth(app.createRepetitionRule, &proOnly), true},
- {"PATCH", "/repetition_rules/{repetitionRuleUUID}", tokenAuth(app.updateRepetitionRule, database.TokenTypeRepetition, &proOnly), true},
- {"DELETE", "/repetition_rules/{repetitionRuleUUID}", auth(app.deleteRepetitionRule, &proOnly), true},
+ {"GET", "/calendar", app.auth(app.getCalendar, &proOnly), true},
+ {"GET", "/repetition_rules", app.auth(app.getRepetitionRules, &proOnly), true},
+ {"GET", "/repetition_rules/{repetitionRuleUUID}", app.tokenAuth(app.getRepetitionRule, database.TokenTypeRepetition, &proOnly), true},
+ {"POST", "/repetition_rules", app.auth(app.createRepetitionRule, &proOnly), true},
+ {"PATCH", "/repetition_rules/{repetitionRuleUUID}", app.tokenAuth(app.updateRepetitionRule, database.TokenTypeRepetition, &proOnly), true},
+ {"DELETE", "/repetition_rules/{repetitionRuleUUID}", app.auth(app.deleteRepetitionRule, &proOnly), true},
// migration of classic users
{"GET", "/classic/presignin", cors(app.classicPresignin), true},
{"POST", "/classic/signin", cors(app.classicSignin), true},
- {"PATCH", "/classic/migrate", auth(app.classicMigrate, &proOnly), true},
- {"GET", "/classic/notes", auth(app.classicGetNotes, nil), true},
- {"PATCH", "/classic/set-password", auth(app.classicSetPassword, nil), true},
+ {"PATCH", "/classic/migrate", app.auth(app.classicMigrate, &proOnly), true},
+ {"GET", "/classic/notes", app.auth(app.classicGetNotes, nil), true},
+ {"PATCH", "/classic/set-password", app.auth(app.classicSetPassword, nil), true},
// v3
- {"GET", "/v3/sync/fragment", cors(auth(app.GetSyncFragment, &proOnly)), true},
- {"GET", "/v3/sync/state", cors(auth(app.GetSyncState, &proOnly)), true},
+ {"GET", "/v3/sync/fragment", cors(app.auth(app.GetSyncFragment, &proOnly)), true},
+ {"GET", "/v3/sync/state", cors(app.auth(app.GetSyncState, &proOnly)), true},
{"OPTIONS", "/v3/books", cors(app.BooksOptions), true},
- {"GET", "/v3/books", cors(auth(app.GetBooks, &proOnly)), true},
- {"GET", "/v3/books/{bookUUID}", cors(auth(app.GetBook, &proOnly)), true},
- {"POST", "/v3/books", cors(auth(app.CreateBook, &proOnly)), true},
- {"PATCH", "/v3/books/{bookUUID}", cors(auth(app.UpdateBook, &proOnly)), false},
- {"DELETE", "/v3/books/{bookUUID}", cors(auth(app.DeleteBook, &proOnly)), false},
- {"GET", "/v3/demo/books", app.GetDemoBooks, true},
+ {"GET", "/v3/books", cors(app.auth(app.GetBooks, &proOnly)), true},
+ {"GET", "/v3/books/{bookUUID}", cors(app.auth(app.GetBook, &proOnly)), true},
+ {"POST", "/v3/books", cors(app.auth(app.CreateBook, &proOnly)), true},
+ {"PATCH", "/v3/books/{bookUUID}", cors(app.auth(app.UpdateBook, &proOnly)), false},
+ {"DELETE", "/v3/books/{bookUUID}", cors(app.auth(app.DeleteBook, &proOnly)), false},
{"OPTIONS", "/v3/notes", cors(app.NotesOptions), true},
- {"POST", "/v3/notes", cors(auth(app.CreateNote, &proOnly)), true},
- {"PATCH", "/v3/notes/{noteUUID}", auth(app.UpdateNote, &proOnly), false},
- {"DELETE", "/v3/notes/{noteUUID}", auth(app.DeleteNote, &proOnly), false},
+ {"POST", "/v3/notes", cors(app.auth(app.CreateNote, &proOnly)), true},
+ {"PATCH", "/v3/notes/{noteUUID}", app.auth(app.UpdateNote, &proOnly), false},
+ {"DELETE", "/v3/notes/{noteUUID}", app.auth(app.DeleteNote, &proOnly), false},
{"POST", "/v3/signin", cors(app.signin), true},
{"OPTIONS", "/v3/signout", cors(app.signoutOptions), true},
{"POST", "/v3/signout", cors(app.signout), true},
diff --git a/pkg/server/handlers/routes_test.go b/pkg/server/handlers/routes_test.go
index 53227e7c..3a43e725 100644
--- a/pkg/server/handlers/routes_test.go
+++ b/pkg/server/handlers/routes_test.go
@@ -29,13 +29,10 @@ import (
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/testutils"
+ "github.com/jinzhu/gorm"
"github.com/pkg/errors"
)
-func init() {
- testutils.InitTestDB()
-}
-
func TestGetSessionKeyFromCookie(t *testing.T) {
testCases := []struct {
cookie *http.Cookie
@@ -185,10 +182,8 @@ func TestGetCredential(t *testing.T) {
}
func TestAuthMiddleware(t *testing.T) {
- defer testutils.ClearData()
- // set up
- db := database.DBConn
+ defer testutils.ClearData()
user := testutils.SetupUserData()
session := database.Session{
@@ -196,18 +191,19 @@ func TestAuthMiddleware(t *testing.T) {
UserID: user.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
- testutils.MustExec(t, db.Save(&session), "preparing session")
+ testutils.MustExec(t, testutils.DB.Save(&session), "preparing session")
session2 := database.Session{
Key: "Vvgm3eBXfXGEFWERI7faiRJ3DAzJw+7DdT9J1LEyNfI=",
UserID: user.ID,
ExpiresAt: time.Now().Add(-time.Hour * 24),
}
- testutils.MustExec(t, db.Save(&session2), "preparing session")
+ testutils.MustExec(t, testutils.DB.Save(&session2), "preparing session")
handler := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
- server := httptest.NewServer(auth(handler, nil))
+ app := App{DB: testutils.DB}
+ server := httptest.NewServer(app.auth(handler, nil))
defer server.Close()
t.Run("with header", func(t *testing.T) {
@@ -300,24 +296,23 @@ func TestAuthMiddleware(t *testing.T) {
}
func TestAuthMiddleware_ProOnly(t *testing.T) {
+
defer testutils.ClearData()
- // set up
- db := database.DBConn
-
user := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&user).Update("cloud", false), "preparing session")
+ testutils.MustExec(t, testutils.DB.Model(&user).Update("cloud", false), "preparing session")
session := database.Session{
Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=",
UserID: user.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
- testutils.MustExec(t, db.Save(&session), "preparing session")
+ testutils.MustExec(t, testutils.DB.Save(&session), "preparing session")
handler := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
- server := httptest.NewServer(auth(handler, &AuthMiddlewareParams{
+ app := App{DB: testutils.DB}
+ server := httptest.NewServer(app.auth(handler, &AuthMiddlewareParams{
ProOnly: true,
}))
defer server.Close()
@@ -390,10 +385,8 @@ func TestAuthMiddleware_ProOnly(t *testing.T) {
}
func TestTokenAuthMiddleWare(t *testing.T) {
- defer testutils.ClearData()
- // set up
- db := database.DBConn
+ defer testutils.ClearData()
user := testutils.SetupUserData()
tok := database.Token{
@@ -401,18 +394,19 @@ func TestTokenAuthMiddleWare(t *testing.T) {
Type: database.TokenTypeEmailPreference,
Value: "xpwFnc0MdllFUePDq9DLeQ==",
}
- testutils.MustExec(t, db.Save(&tok), "preparing token")
+ testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
session := database.Session{
Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=",
UserID: user.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
- testutils.MustExec(t, db.Save(&session), "preparing session")
+ testutils.MustExec(t, testutils.DB.Save(&session), "preparing session")
handler := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
- server := httptest.NewServer(tokenAuth(handler, database.TokenTypeEmailPreference, nil))
+ app := App{DB: testutils.DB}
+ server := httptest.NewServer(app.tokenAuth(handler, database.TokenTypeEmailPreference, nil))
defer server.Close()
t.Run("with token", func(t *testing.T) {
@@ -521,30 +515,29 @@ func TestTokenAuthMiddleWare(t *testing.T) {
}
func TestTokenAuthMiddleWare_ProOnly(t *testing.T) {
+
defer testutils.ClearData()
- // set up
- db := database.DBConn
-
user := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&user).Update("cloud", false), "preparing session")
+ testutils.MustExec(t, testutils.DB.Model(&user).Update("cloud", false), "preparing session")
tok := database.Token{
UserID: user.ID,
Type: database.TokenTypeEmailPreference,
Value: "xpwFnc0MdllFUePDq9DLeQ==",
}
- testutils.MustExec(t, db.Save(&tok), "preparing token")
+ testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
session := database.Session{
Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=",
UserID: user.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
- testutils.MustExec(t, db.Save(&session), "preparing session")
+ testutils.MustExec(t, testutils.DB.Save(&session), "preparing session")
handler := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
- server := httptest.NewServer(tokenAuth(handler, database.TokenTypeEmailPreference, &AuthMiddlewareParams{
+ app := App{DB: testutils.DB}
+ server := httptest.NewServer(app.tokenAuth(handler, database.TokenTypeEmailPreference, &AuthMiddlewareParams{
ProOnly: true,
}))
defer server.Close()
@@ -682,6 +675,7 @@ func TestNotSupportedVersions(t *testing.T) {
// setup
server := MustNewServer(t, &App{
+ DB: &gorm.DB{},
Clock: clock.NewMock(),
})
defer server.Close()
diff --git a/pkg/server/handlers/subscription.go b/pkg/server/handlers/subscription.go
index 1fe700b6..1f727ca8 100644
--- a/pkg/server/handlers/subscription.go
+++ b/pkg/server/handlers/subscription.go
@@ -26,9 +26,9 @@ import (
"os"
"strings"
+ "github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/operations"
- "github.com/dnote/dnote/pkg/server/database"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
"github.com/stripe/stripe-go"
@@ -138,8 +138,7 @@ func (a *App) createSub(w http.ResponseWriter, r *http.Request) {
return
}
- db := database.DBConn
- tx := db.Begin()
+ tx := a.DB.Begin()
if err := tx.Model(&user).
Update(map[string]interface{}{
@@ -431,8 +430,7 @@ func (a *App) updateStripeSource(w http.ResponseWriter, r *http.Request) {
return
}
- db := database.DBConn
- tx := db.Begin()
+ tx := a.DB.Begin()
if err := tx.Model(&user).
Update(map[string]interface{}{
@@ -532,7 +530,7 @@ func (a *App) stripeWebhook(w http.ResponseWriter, req *http.Request) {
return
}
- operations.MarkUnsubscribed(subscription.Customer.ID)
+ operations.MarkUnsubscribed(a.DB, subscription.Customer.ID)
}
default:
{
diff --git a/pkg/server/handlers/testutils.go b/pkg/server/handlers/testutils.go
index ca979b2f..187975fa 100644
--- a/pkg/server/handlers/testutils.go
+++ b/pkg/server/handlers/testutils.go
@@ -23,6 +23,7 @@ import (
"os"
"testing"
+ "github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
)
@@ -30,6 +31,7 @@ import (
// with the given app paratmers
func MustNewServer(t *testing.T, app *App) *httptest.Server {
app.WebURL = os.Getenv("WebURL")
+ app.DB = testutils.DB
r, err := NewRouter(app)
if err != nil {
diff --git a/pkg/server/handlers/user.go b/pkg/server/handlers/user.go
index e7acc47f..82432d57 100644
--- a/pkg/server/handlers/user.go
+++ b/pkg/server/handlers/user.go
@@ -23,11 +23,12 @@ import (
"net/http"
"time"
- "github.com/dnote/dnote/pkg/server/helpers"
- "github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/database"
+ "github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/log"
"github.com/dnote/dnote/pkg/server/mailer"
+ "github.com/dnote/dnote/pkg/server/presenters"
+ "github.com/jinzhu/gorm"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
)
@@ -38,8 +39,6 @@ type updateProfilePayload struct {
// updateProfile updates user
func (a *App) updateProfile(w http.ResponseWriter, r *http.Request) {
- db := database.DBConn
-
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
if !ok {
HandleError(w, "No authenticated user found", nil, http.StatusInternalServerError)
@@ -60,13 +59,13 @@ func (a *App) updateProfile(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
- err = db.Where("user_id = ?", user.ID).First(&account).Error
+ err = a.DB.Where("user_id = ?", user.ID).First(&account).Error
if err != nil {
HandleError(w, "finding account", err, http.StatusInternalServerError)
return
}
- tx := db.Begin()
+ tx := a.DB.Begin()
if err := tx.Save(&user).Error; err != nil {
tx.Rollback()
HandleError(w, "saving user", err, http.StatusInternalServerError)
@@ -87,7 +86,7 @@ func (a *App) updateProfile(w http.ResponseWriter, r *http.Request) {
tx.Commit()
- respondWithSession(w, user.ID, http.StatusOK)
+ respondWithSession(a.DB, w, user.ID, http.StatusOK)
}
type updateEmailPayload struct {
@@ -97,9 +96,7 @@ type updateEmailPayload struct {
NewAuthKey string `json:"new_auth_key"`
}
-func respondWithCalendar(w http.ResponseWriter, userID int) {
- db := database.DBConn
-
+func respondWithCalendar(db *gorm.DB, w http.ResponseWriter, userID int) {
rows, err := db.Table("notes").Select("COUNT(id), date(to_timestamp(added_on/1000000000)) AS added_date").
Where("user_id = ?", userID).
Group("added_date").
@@ -132,22 +129,10 @@ func (a *App) getCalendar(w http.ResponseWriter, r *http.Request) {
return
}
- respondWithCalendar(w, user.ID)
-}
-
-func (a *App) getDemoCalendar(w http.ResponseWriter, r *http.Request) {
- userID, err := helpers.GetDemoUserID()
- if err != nil {
- HandleError(w, "finding demo user", err, http.StatusInternalServerError)
- return
- }
-
- respondWithCalendar(w, userID)
+ respondWithCalendar(a.DB, w, user.ID)
}
func (a *App) createVerificationToken(w http.ResponseWriter, r *http.Request) {
- db := database.DBConn
-
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
if !ok {
HandleError(w, "No authenticated user found", nil, http.StatusInternalServerError)
@@ -155,7 +140,7 @@ func (a *App) createVerificationToken(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
- err := db.Where("user_id = ?", user.ID).First(&account).Error
+ err := a.DB.Where("user_id = ?", user.ID).First(&account).Error
if err != nil {
HandleError(w, "finding account", err, http.StatusInternalServerError)
return
@@ -182,7 +167,7 @@ func (a *App) createVerificationToken(w http.ResponseWriter, r *http.Request) {
Type: database.TokenTypeEmailVerification,
}
- if err := db.Save(&token).Error; err != nil {
+ if err := a.DB.Save(&token).Error; err != nil {
HandleError(w, "saving token", err, http.StatusInternalServerError)
return
}
@@ -212,8 +197,6 @@ type verifyEmailPayload struct {
}
func (a *App) verifyEmail(w http.ResponseWriter, r *http.Request) {
- db := database.DBConn
-
var params verifyEmailPayload
if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil {
HandleError(w, "decoding payload", err, http.StatusInternalServerError)
@@ -221,7 +204,7 @@ func (a *App) verifyEmail(w http.ResponseWriter, r *http.Request) {
}
var token database.Token
- if err := db.
+ if err := a.DB.
Where("value = ? AND type = ?", params.Token, database.TokenTypeEmailVerification).
First(&token).Error; err != nil {
http.Error(w, "invalid token", http.StatusBadRequest)
@@ -240,7 +223,7 @@ func (a *App) verifyEmail(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
- if err := db.Where("user_id = ?", token.UserID).First(&account).Error; err != nil {
+ if err := a.DB.Where("user_id = ?", token.UserID).First(&account).Error; err != nil {
HandleError(w, "finding account", err, http.StatusInternalServerError)
return
}
@@ -249,7 +232,7 @@ func (a *App) verifyEmail(w http.ResponseWriter, r *http.Request) {
return
}
- tx := db.Begin()
+ tx := a.DB.Begin()
account.EmailVerified = true
if err := tx.Save(&account).Error; err != nil {
tx.Rollback()
@@ -264,7 +247,7 @@ func (a *App) verifyEmail(w http.ResponseWriter, r *http.Request) {
tx.Commit()
var user database.User
- if err := db.Where("id = ?", token.UserID).First(&user).Error; err != nil {
+ if err := a.DB.Where("id = ?", token.UserID).First(&user).Error; err != nil {
HandleError(w, "finding user", err, http.StatusInternalServerError)
return
}
@@ -278,8 +261,6 @@ type updateEmailPreferencePayload struct {
}
func (a *App) updateEmailPreference(w http.ResponseWriter, r *http.Request) {
- db := database.DBConn
-
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
if !ok {
HandleError(w, "No authenticated user found", nil, http.StatusInternalServerError)
@@ -293,12 +274,12 @@ func (a *App) updateEmailPreference(w http.ResponseWriter, r *http.Request) {
}
var frequency database.EmailPreference
- if err := db.Where(database.EmailPreference{UserID: user.ID}).FirstOrCreate(&frequency).Error; err != nil {
+ if err := a.DB.Where(database.EmailPreference{UserID: user.ID}).FirstOrCreate(&frequency).Error; err != nil {
HandleError(w, "finding frequency", err, http.StatusInternalServerError)
return
}
- tx := db.Begin()
+ tx := a.DB.Begin()
frequency.DigestWeekly = params.DigestWeekly
if err := tx.Save(&frequency).Error; err != nil {
@@ -323,8 +304,6 @@ func (a *App) updateEmailPreference(w http.ResponseWriter, r *http.Request) {
}
func (a *App) getEmailPreference(w http.ResponseWriter, r *http.Request) {
- db := database.DBConn
-
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
if !ok {
HandleError(w, "No authenticated user found", nil, http.StatusInternalServerError)
@@ -332,7 +311,7 @@ func (a *App) getEmailPreference(w http.ResponseWriter, r *http.Request) {
}
var pref database.EmailPreference
- if err := db.Where(database.EmailPreference{UserID: user.ID}).First(&pref).Error; err != nil {
+ if err := a.DB.Where(database.EmailPreference{UserID: user.ID}).First(&pref).Error; err != nil {
HandleError(w, "finding pref", err, http.StatusInternalServerError)
return
}
@@ -347,8 +326,6 @@ type updatePasswordPayload struct {
}
func (a *App) updatePassword(w http.ResponseWriter, r *http.Request) {
- db := database.DBConn
-
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
if !ok {
HandleError(w, "No authenticated user found", nil, http.StatusInternalServerError)
@@ -366,7 +343,7 @@ func (a *App) updatePassword(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
- if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
+ if err := a.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
HandleError(w, "getting user", nil, http.StatusInternalServerError)
return
}
@@ -391,7 +368,7 @@ func (a *App) updatePassword(w http.ResponseWriter, r *http.Request) {
return
}
- if err := db.Model(&account).Update("password", string(hashedNewPassword)).Error; err != nil {
+ if err := a.DB.Model(&account).Update("password", string(hashedNewPassword)).Error; err != nil {
http.Error(w, errors.Wrap(err, "updating password").Error(), http.StatusInternalServerError)
return
}
diff --git a/pkg/server/handlers/user_test.go b/pkg/server/handlers/user_test.go
index cd6de035..877ee55e 100644
--- a/pkg/server/handlers/user_test.go
+++ b/pkg/server/handlers/user_test.go
@@ -36,18 +36,14 @@ import (
"golang.org/x/crypto/bcrypt"
)
-func init() {
- testutils.InitTestDB()
-
-}
-
func TestUpdatePassword(t *testing.T) {
t.Run("success", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -64,18 +60,19 @@ func TestUpdatePassword(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "Status code mismsatch")
var account database.Account
- testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte("newpassword"))
assert.Equal(t, passwordErr, nil, "Password mismatch")
})
t.Run("old password mismatch", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -92,16 +89,17 @@ func TestUpdatePassword(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "Status code mismsatch")
var account database.Account
- testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account")
assert.Equal(t, a.Password.String, account.Password.String, "password should not have been updated")
})
t.Run("password too short", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -118,15 +116,15 @@ func TestUpdatePassword(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismsatch")
var account database.Account
- testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account")
assert.Equal(t, a.Password.String, account.Password.String, "password should not have been updated")
})
}
func TestCreateVerificationToken(t *testing.T) {
t.Run("success", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
@@ -137,6 +135,7 @@ func TestCreateVerificationToken(t *testing.T) {
mailer.InitTemplates(&templatePath)
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -154,9 +153,9 @@ func TestCreateVerificationToken(t *testing.T) {
var account database.Account
var token database.Token
var tokenCount int
- testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
- testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
- testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
+ testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
assert.Equal(t, account.EmailVerified, false, "email_verified should not have been updated")
assert.NotEqual(t, token.Value, "", "token Value mismatch")
@@ -165,11 +164,12 @@ func TestCreateVerificationToken(t *testing.T) {
})
t.Run("already verified", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -177,7 +177,7 @@ func TestCreateVerificationToken(t *testing.T) {
user := testutils.SetupUserData()
a := testutils.SetupAccountData(user, "alice@example.com", "pass1234")
a.EmailVerified = true
- testutils.MustExec(t, db.Save(&a), "preparing account")
+ testutils.MustExec(t, testutils.DB.Save(&a), "preparing account")
// Execute
req := testutils.MakeReq(server, "POST", "/verification-token", "")
@@ -188,8 +188,8 @@ func TestCreateVerificationToken(t *testing.T) {
var account database.Account
var tokenCount int
- testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
- testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
+ testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
assert.Equal(t, account.EmailVerified, true, "email_verified should not have been updated")
assert.Equal(t, tokenCount, 0, "token count mismatch")
@@ -198,11 +198,12 @@ func TestCreateVerificationToken(t *testing.T) {
func TestVerifyEmail(t *testing.T) {
t.Run("success", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -214,7 +215,7 @@ func TestVerifyEmail(t *testing.T) {
Type: database.TokenTypeEmailVerification,
Value: "someTokenValue",
}
- testutils.MustExec(t, db.Save(&tok), "preparing token")
+ testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
dat := `{"token": "someTokenValue"}`
req := testutils.MakeReq(server, "PATCH", "/verify-email", dat)
@@ -228,9 +229,9 @@ func TestVerifyEmail(t *testing.T) {
var account database.Account
var token database.Token
var tokenCount int
- testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
- testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
- testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
+ testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
assert.Equal(t, account.EmailVerified, true, "email_verified mismatch")
assert.NotEqual(t, token.Value, "", "token value should not have been updated")
@@ -239,11 +240,12 @@ func TestVerifyEmail(t *testing.T) {
})
t.Run("used token", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -258,7 +260,7 @@ func TestVerifyEmail(t *testing.T) {
Value: "someTokenValue",
UsedAt: &usedAt,
}
- testutils.MustExec(t, db.Save(&tok), "preparing token")
+ testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
dat := `{"token": "someTokenValue"}`
req := testutils.MakeReq(server, "PATCH", "/verify-email", dat)
@@ -272,9 +274,9 @@ func TestVerifyEmail(t *testing.T) {
var account database.Account
var token database.Token
var tokenCount int
- testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
- testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
- testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
+ testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
assert.Equal(t, account.EmailVerified, false, "email_verified mismatch")
assert.NotEqual(t, token.UsedAt, nil, "token used_at mismatch")
@@ -283,11 +285,12 @@ func TestVerifyEmail(t *testing.T) {
})
t.Run("expired token", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -300,8 +303,8 @@ func TestVerifyEmail(t *testing.T) {
Type: database.TokenTypeEmailVerification,
Value: "someTokenValue",
}
- testutils.MustExec(t, db.Save(&tok), "preparing token")
- testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-31)), "Failed to prepare token created_at")
+ testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
+ testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-31)), "Failed to prepare token created_at")
dat := `{"token": "someTokenValue"}`
req := testutils.MakeReq(server, "PATCH", "/verify-email", dat)
@@ -315,9 +318,9 @@ func TestVerifyEmail(t *testing.T) {
var account database.Account
var token database.Token
var tokenCount int
- testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
- testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
- testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
+ testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
assert.Equal(t, account.EmailVerified, false, "email_verified mismatch")
assert.Equal(t, tokenCount, 1, "token count mismatch")
@@ -325,11 +328,12 @@ func TestVerifyEmail(t *testing.T) {
})
t.Run("already verified", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -337,14 +341,14 @@ func TestVerifyEmail(t *testing.T) {
user := testutils.SetupUserData()
a := testutils.SetupAccountData(user, "alice@example.com", "oldpass1234")
a.EmailVerified = true
- testutils.MustExec(t, db.Save(&a), "preparing account")
+ testutils.MustExec(t, testutils.DB.Save(&a), "preparing account")
tok := database.Token{
UserID: user.ID,
Type: database.TokenTypeEmailVerification,
Value: "someTokenValue",
}
- testutils.MustExec(t, db.Save(&tok), "preparing token")
+ testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
dat := `{"token": "someTokenValue"}`
req := testutils.MakeReq(server, "PATCH", "/verify-email", dat)
@@ -358,9 +362,9 @@ func TestVerifyEmail(t *testing.T) {
var account database.Account
var token database.Token
var tokenCount int
- testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
- testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
- testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
+ testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
assert.Equal(t, account.EmailVerified, true, "email_verified mismatch")
assert.Equal(t, tokenCount, 1, "token count mismatch")
@@ -370,11 +374,12 @@ func TestVerifyEmail(t *testing.T) {
func TestUpdateEmail(t *testing.T) {
t.Run("success", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -382,7 +387,7 @@ func TestUpdateEmail(t *testing.T) {
u := testutils.SetupUserData()
a := testutils.SetupAccountData(u, "alice@example.com", "pass1234")
a.EmailVerified = true
- testutils.MustExec(t, db.Save(&a), "updating email_verified")
+ testutils.MustExec(t, testutils.DB.Save(&a), "updating email_verified")
// Execute
dat := `{"email": "alice-new@example.com"}`
@@ -394,8 +399,8 @@ func TestUpdateEmail(t *testing.T) {
var user database.User
var account database.Account
- testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user")
- testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", u.ID).First(&user), "finding user")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account")
assert.Equal(t, account.Email.String, "alice-new@example.com", "email mismatch")
assert.Equal(t, account.EmailVerified, false, "EmailVerified mismatch")
@@ -404,11 +409,12 @@ func TestUpdateEmail(t *testing.T) {
func TestUpdateEmailPreference(t *testing.T) {
t.Run("with login", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -425,16 +431,17 @@ func TestUpdateEmailPreference(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "")
var preference database.EmailPreference
- testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding account")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding account")
assert.Equal(t, preference.DigestWeekly, true, "preference mismatch")
})
t.Run("with an unused token", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -446,7 +453,7 @@ func TestUpdateEmailPreference(t *testing.T) {
Type: database.TokenTypeEmailPreference,
Value: "someTokenValue",
}
- testutils.MustExec(t, db.Save(&tok), "preparing token")
+ testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
// Execute
dat := `{"digest_weekly": true}`
@@ -460,9 +467,9 @@ func TestUpdateEmailPreference(t *testing.T) {
var preference database.EmailPreference
var preferenceCount int
var token database.Token
- testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding preference")
- testutils.MustExec(t, db.Model(database.EmailPreference{}).Count(&preferenceCount), "counting preference")
- testutils.MustExec(t, db.Where("id = ?", tok.ID).First(&token), "failed to find token")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding preference")
+ testutils.MustExec(t, testutils.DB.Model(database.EmailPreference{}).Count(&preferenceCount), "counting preference")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", tok.ID).First(&token), "failed to find token")
assert.Equal(t, preferenceCount, 1, "preference count mismatch")
assert.Equal(t, preference.DigestWeekly, true, "email mismatch")
@@ -470,11 +477,12 @@ func TestUpdateEmailPreference(t *testing.T) {
})
t.Run("with nonexistent token", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -486,7 +494,7 @@ func TestUpdateEmailPreference(t *testing.T) {
Type: database.TokenTypeEmailPreference,
Value: "someTokenValue",
}
- testutils.MustExec(t, db.Save(&tok), "preparing token")
+ testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
dat := `{"digest_weekly": false}`
url := fmt.Sprintf("/account/email-preference?token=%s", "someNonexistentToken")
@@ -499,16 +507,17 @@ func TestUpdateEmailPreference(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
var preference database.EmailPreference
- testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding preference")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding preference")
assert.Equal(t, preference.DigestWeekly, true, "email mismatch")
})
t.Run("with expired token", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -523,7 +532,7 @@ func TestUpdateEmailPreference(t *testing.T) {
Value: "someTokenValue",
UsedAt: &usedAt,
}
- testutils.MustExec(t, db.Save(&tok), "preparing token")
+ testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
// Execute
dat := `{"digest_weekly": false}`
@@ -535,16 +544,17 @@ func TestUpdateEmailPreference(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
var preference database.EmailPreference
- testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding preference")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding preference")
assert.Equal(t, preference.DigestWeekly, true, "email mismatch")
})
t.Run("with a used but unexpired token", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -558,7 +568,7 @@ func TestUpdateEmailPreference(t *testing.T) {
Value: "someTokenValue",
UsedAt: &usedAt,
}
- testutils.MustExec(t, db.Save(&tok), "preparing token")
+ testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
dat := `{"digest_weekly": false}`
url := fmt.Sprintf("/account/email-preference?token=%s", "someTokenValue")
@@ -571,16 +581,17 @@ func TestUpdateEmailPreference(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "")
var preference database.EmailPreference
- testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding preference")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding preference")
assert.Equal(t, preference.DigestWeekly, false, "DigestWeekly mismatch")
})
t.Run("no user and no token", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -597,16 +608,17 @@ func TestUpdateEmailPreference(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
var preference database.EmailPreference
- testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding preference")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding preference")
assert.Equal(t, preference.DigestWeekly, true, "email mismatch")
})
t.Run("create a record if not exists", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -617,7 +629,7 @@ func TestUpdateEmailPreference(t *testing.T) {
Type: database.TokenTypeEmailPreference,
Value: "someTokenValue",
}
- testutils.MustExec(t, db.Save(&tok), "preparing token")
+ testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
// Execute
dat := `{"digest_weekly": false}`
@@ -629,20 +641,21 @@ func TestUpdateEmailPreference(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "")
var preferenceCount int
- testutils.MustExec(t, db.Model(database.EmailPreference{}).Count(&preferenceCount), "counting preference")
+ testutils.MustExec(t, testutils.DB.Model(database.EmailPreference{}).Count(&preferenceCount), "counting preference")
assert.Equal(t, preferenceCount, 1, "preference count mismatch")
var preference database.EmailPreference
- testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding preference")
+ testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding preference")
assert.Equal(t, preference.DigestWeekly, false, "email mismatch")
})
}
func TestGetEmailPreference(t *testing.T) {
- defer testutils.ClearData()
+ defer testutils.ClearData()
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
diff --git a/pkg/server/handlers/v3_auth.go b/pkg/server/handlers/v3_auth.go
index 18cf194b..4283dae6 100644
--- a/pkg/server/handlers/v3_auth.go
+++ b/pkg/server/handlers/v3_auth.go
@@ -23,8 +23,9 @@ import (
"net/http"
"time"
- "github.com/dnote/dnote/pkg/server/operations"
"github.com/dnote/dnote/pkg/server/database"
+ "github.com/dnote/dnote/pkg/server/operations"
+ "github.com/jinzhu/gorm"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
)
@@ -63,9 +64,7 @@ func unsetSessionCookie(w http.ResponseWriter) {
http.SetCookie(w, &cookie)
}
-func touchLastLoginAt(user database.User) error {
- db := database.DBConn
-
+func touchLastLoginAt(db *gorm.DB, user database.User) error {
t := time.Now()
if err := db.Model(&user).Update(database.User{LastLoginAt: &t}).Error; err != nil {
return errors.Wrap(err, "updating last_login_at")
@@ -80,8 +79,6 @@ type signinPayload struct {
}
func (a *App) signin(w http.ResponseWriter, r *http.Request) {
- db := database.DBConn
-
var params signinPayload
err := json.NewDecoder(r.Body).Decode(¶ms)
if err != nil {
@@ -94,7 +91,7 @@ func (a *App) signin(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
- conn := db.Where("email = ?", params.Email).First(&account)
+ conn := a.DB.Where("email = ?", params.Email).First(&account)
if conn.RecordNotFound() {
http.Error(w, ErrLoginFailure.Error(), http.StatusUnauthorized)
return
@@ -111,19 +108,19 @@ func (a *App) signin(w http.ResponseWriter, r *http.Request) {
}
var user database.User
- err = db.Where("id = ?", account.UserID).First(&user).Error
+ err = a.DB.Where("id = ?", account.UserID).First(&user).Error
if err != nil {
HandleError(w, "finding user", err, http.StatusInternalServerError)
return
}
- err = operations.TouchLastLoginAt(user, db)
+ err = operations.TouchLastLoginAt(user, a.DB)
if err != nil {
http.Error(w, errors.Wrap(err, "touching login timestamp").Error(), http.StatusInternalServerError)
return
}
- respondWithSession(w, account.UserID, http.StatusOK)
+ respondWithSession(a.DB, w, account.UserID, http.StatusOK)
}
func (a *App) signoutOptions(w http.ResponseWriter, r *http.Request) {
@@ -143,7 +140,7 @@ func (a *App) signout(w http.ResponseWriter, r *http.Request) {
return
}
- err = operations.DeleteSession(database.DBConn, key)
+ err = operations.DeleteSession(a.DB, key)
if err != nil {
HandleError(w, "deleting session", nil, http.StatusInternalServerError)
return
@@ -182,8 +179,6 @@ func parseRegisterPaylaod(r *http.Request) (registerPayload, bool) {
}
func (a *App) register(w http.ResponseWriter, r *http.Request) {
- db := database.DBConn
-
params, ok := parseRegisterPaylaod(r)
if !ok {
http.Error(w, "invalid payload", http.StatusBadRequest)
@@ -191,7 +186,7 @@ func (a *App) register(w http.ResponseWriter, r *http.Request) {
}
var count int
- if err := db.Model(database.Account{}).Where("email = ?", params.Email).Count(&count).Error; err != nil {
+ if err := a.DB.Model(database.Account{}).Where("email = ?", params.Email).Count(&count).Error; err != nil {
HandleError(w, "checking duplicate user", err, http.StatusInternalServerError)
return
}
@@ -200,20 +195,18 @@ func (a *App) register(w http.ResponseWriter, r *http.Request) {
return
}
- user, err := operations.CreateUser(params.Email, params.Password)
+ user, err := operations.CreateUser(a.DB, params.Email, params.Password)
if err != nil {
HandleError(w, "creating user", err, http.StatusInternalServerError)
return
}
- respondWithSession(w, user.ID, http.StatusCreated)
+ respondWithSession(a.DB, w, user.ID, http.StatusCreated)
}
// respondWithSession makes a HTTP response with the session from the user with the given userID.
// It sets the HTTP-Only cookie for browser clients and also sends a JSON response for non-browser clients.
-func respondWithSession(w http.ResponseWriter, userID int, statusCode int) {
- db := database.DBConn
-
+func respondWithSession(db *gorm.DB, w http.ResponseWriter, userID int, statusCode int) {
session, err := operations.CreateSession(db, userID)
if err != nil {
HandleError(w, "creating session", nil, http.StatusBadRequest)
diff --git a/pkg/server/handlers/v3_auth_test.go b/pkg/server/handlers/v3_auth_test.go
index f9c78b3e..9d89d424 100644
--- a/pkg/server/handlers/v3_auth_test.go
+++ b/pkg/server/handlers/v3_auth_test.go
@@ -22,26 +22,17 @@ import (
"encoding/json"
"fmt"
"net/http"
- "os"
"testing"
"time"
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/database"
- "github.com/dnote/dnote/pkg/server/mailer"
"github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
)
-func init() {
- testutils.InitTestDB()
-
- templatePath := os.Getenv("DNOTE_TEST_EMAIL_TEMPLATE_DIR")
- mailer.InitTemplates(&templatePath)
-}
-
func assertSessionResp(t *testing.T, res *http.Response) {
// after register, should sign in user
var got SessionResponse
@@ -51,9 +42,8 @@ func assertSessionResp(t *testing.T, res *http.Response) {
var sessionCount int
var session database.Session
- db := database.DBConn
- testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
- testutils.MustExec(t, db.First(&session), "getting session")
+ testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
+ testutils.MustExec(t, testutils.DB.First(&session), "getting session")
assert.Equal(t, sessionCount, 1, "sessionCount mismatch")
assert.Equal(t, got.Key, session.Key, "session Key mismatch")
@@ -87,11 +77,12 @@ func TestRegister(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("register %s %s", tc.email, tc.password), func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -106,20 +97,20 @@ func TestRegister(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusCreated, "")
var account database.Account
- testutils.MustExec(t, db.Where("email = ?", tc.email).First(&account), "finding account")
+ testutils.MustExec(t, testutils.DB.Where("email = ?", tc.email).First(&account), "finding account")
assert.Equal(t, account.Email.String, tc.email, "Email mismatch")
assert.NotEqual(t, account.UserID, 0, "UserID mismatch")
passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte(tc.password))
assert.Equal(t, passwordErr, nil, "Password mismatch")
var user database.User
- testutils.MustExec(t, db.Where("id = ?", account.UserID).First(&user), "finding user")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", account.UserID).First(&user), "finding user")
assert.Equal(t, user.Cloud, false, "Cloud mismatch")
assert.Equal(t, user.StripeCustomerID, "", "StripeCustomerID mismatch")
assert.Equal(t, user.MaxUSN, 0, "MaxUSN mismatch")
var repetitionRuleCount int
- testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Where("user_id = ?", account.UserID).Count(&repetitionRuleCount), "counting repetition rules")
+ testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Where("user_id = ?", account.UserID).Count(&repetitionRuleCount), "counting repetition rules")
assert.Equal(t, repetitionRuleCount, 1, "repetitionRuleCount mismatch")
// after register, should sign in user
@@ -130,11 +121,12 @@ func TestRegister(t *testing.T) {
func TestRegisterMissingParams(t *testing.T) {
t.Run("missing email", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -149,19 +141,20 @@ func TestRegisterMissingParams(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch")
var accountCount, userCount int
- testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
- testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
+ testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account")
+ testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
assert.Equal(t, accountCount, 0, "accountCount mismatch")
assert.Equal(t, userCount, 0, "userCount mismatch")
})
t.Run("missing password", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -176,8 +169,8 @@ func TestRegisterMissingParams(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch")
var accountCount, userCount int
- testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
- testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
+ testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account")
+ testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
assert.Equal(t, accountCount, 0, "accountCount mismatch")
assert.Equal(t, userCount, 0, "userCount mismatch")
@@ -185,11 +178,12 @@ func TestRegisterMissingParams(t *testing.T) {
}
func TestRegisterDuplicateEmail(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -207,12 +201,12 @@ func TestRegisterDuplicateEmail(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "status code mismatch")
var accountCount, userCount, verificationTokenCount int
- testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
- testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
- testutils.MustExec(t, db.Model(&database.Token{}).Count(&verificationTokenCount), "counting verification token")
+ testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account")
+ testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
+ testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&verificationTokenCount), "counting verification token")
var user database.User
- testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", u.ID).First(&user), "finding user")
assert.Equal(t, accountCount, 1, "account count mismatch")
assert.Equal(t, userCount, 1, "user count mismatch")
@@ -222,11 +216,12 @@ func TestRegisterDuplicateEmail(t *testing.T) {
func TestSignIn(t *testing.T) {
t.Run("success", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -244,7 +239,7 @@ func TestSignIn(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "")
var user database.User
- testutils.MustExec(t, db.Model(&database.User{}).First(&user), "finding user")
+ testutils.MustExec(t, testutils.DB.Model(&database.User{}).First(&user), "finding user")
assert.NotEqual(t, user.LastLoginAt, nil, "LastLoginAt mismatch")
// after register, should sign in user
@@ -252,11 +247,12 @@ func TestSignIn(t *testing.T) {
})
t.Run("wrong password", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -274,20 +270,21 @@ func TestSignIn(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
var user database.User
- testutils.MustExec(t, db.Model(&database.User{}).First(&user), "finding user")
+ testutils.MustExec(t, testutils.DB.Model(&database.User{}).First(&user), "finding user")
assert.Equal(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch")
var sessionCount int
- testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
+ testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
assert.Equal(t, sessionCount, 0, "sessionCount mismatch")
})
t.Run("wrong email", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -305,20 +302,21 @@ func TestSignIn(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
var user database.User
- testutils.MustExec(t, db.Model(&database.User{}).First(&user), "finding user")
+ testutils.MustExec(t, testutils.DB.Model(&database.User{}).First(&user), "finding user")
assert.DeepEqual(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch")
var sessionCount int
- testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
+ testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
assert.Equal(t, sessionCount, 0, "sessionCount mismatch")
})
t.Run("nonexistent email", func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -333,14 +331,14 @@ func TestSignIn(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
var sessionCount int
- testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
+ testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
assert.Equal(t, sessionCount, 0, "sessionCount mismatch")
})
}
func TestSignout(t *testing.T) {
t.Run("authenticated", func(t *testing.T) {
- db := database.DBConn
+
defer testutils.ClearData()
aliceUser := testutils.SetupUserData()
@@ -352,16 +350,17 @@ func TestSignout(t *testing.T) {
UserID: aliceUser.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
- testutils.MustExec(t, db.Save(&session1), "preparing session1")
+ testutils.MustExec(t, testutils.DB.Save(&session1), "preparing session1")
session2 := database.Session{
Key: "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=",
UserID: anotherUser.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
- testutils.MustExec(t, db.Save(&session2), "preparing session2")
+ testutils.MustExec(t, testutils.DB.Save(&session2), "preparing session2")
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -376,8 +375,8 @@ func TestSignout(t *testing.T) {
var sessionCount int
var s2 database.Session
- testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
- testutils.MustExec(t, db.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&s2), "getting s2")
+ testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
+ testutils.MustExec(t, testutils.DB.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&s2), "getting s2")
assert.Equal(t, sessionCount, 1, "sessionCount mismatch")
@@ -391,7 +390,7 @@ func TestSignout(t *testing.T) {
})
t.Run("unauthenticated", func(t *testing.T) {
- db := database.DBConn
+
defer testutils.ClearData()
aliceUser := testutils.SetupUserData()
@@ -403,16 +402,17 @@ func TestSignout(t *testing.T) {
UserID: aliceUser.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
- testutils.MustExec(t, db.Save(&session1), "preparing session1")
+ testutils.MustExec(t, testutils.DB.Save(&session1), "preparing session1")
session2 := database.Session{
Key: "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=",
UserID: anotherUser.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
- testutils.MustExec(t, db.Save(&session2), "preparing session2")
+ testutils.MustExec(t, testutils.DB.Save(&session2), "preparing session2")
// Setup
server := MustNewServer(t, &App{
+
Clock: clock.NewMock(),
})
defer server.Close()
@@ -426,9 +426,9 @@ func TestSignout(t *testing.T) {
var sessionCount int
var postSession1, postSession2 database.Session
- testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
- testutils.MustExec(t, db.Where("key = ?", "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=").First(&postSession1), "getting postSession1")
- testutils.MustExec(t, db.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&postSession2), "getting postSession2")
+ testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
+ testutils.MustExec(t, testutils.DB.Where("key = ?", "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=").First(&postSession1), "getting postSession1")
+ testutils.MustExec(t, testutils.DB.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&postSession2), "getting postSession2")
// two existing sessions should remain
assert.Equal(t, sessionCount, 2, "sessionCount mismatch")
diff --git a/pkg/server/handlers/v3_books.go b/pkg/server/handlers/v3_books.go
index 7be082d5..2cac6d61 100644
--- a/pkg/server/handlers/v3_books.go
+++ b/pkg/server/handlers/v3_books.go
@@ -24,11 +24,12 @@ import (
"net/http"
"net/url"
+ "github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/operations"
"github.com/dnote/dnote/pkg/server/presenters"
- "github.com/dnote/dnote/pkg/server/database"
"github.com/gorilla/mux"
+ "github.com/jinzhu/gorm"
"github.com/pkg/errors"
)
@@ -69,10 +70,8 @@ func (a *App) CreateBook(w http.ResponseWriter, r *http.Request) {
return
}
- db := database.DBConn
-
var bookCount int
- err = db.Model(database.Book{}).
+ err = a.DB.Model(database.Book{}).
Where("user_id = ? AND label = ?", user.ID, params.Name).
Count(&bookCount).Error
if err != nil {
@@ -84,7 +83,7 @@ func (a *App) CreateBook(w http.ResponseWriter, r *http.Request) {
return
}
- book, err := operations.CreateBook(user, a.Clock, params.Name)
+ book, err := operations.CreateBook(a.DB, user, a.Clock, params.Name)
if err != nil {
HandleError(w, "inserting book", err, http.StatusInternalServerError)
}
@@ -100,9 +99,7 @@ func (a *App) BooksOptions(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Version")
}
-func respondWithBooks(userID int, query url.Values, w http.ResponseWriter) {
- db := database.DBConn
-
+func respondWithBooks(db *gorm.DB, userID int, query url.Values, w http.ResponseWriter) {
var books []database.Book
conn := db.Where("user_id = ? AND NOT deleted", userID).Order("label ASC")
name := query.Get("name")
@@ -132,19 +129,6 @@ func respondWithBooks(userID int, query url.Values, w http.ResponseWriter) {
respondJSON(w, http.StatusOK, presentedBooks)
}
-// GetDemoBooks returns books for demo
-func (a *App) GetDemoBooks(w http.ResponseWriter, r *http.Request) {
- demoUserID, err := helpers.GetDemoUserID()
- if err != nil {
- HandleError(w, "finding demo user", err, http.StatusInternalServerError)
- return
- }
-
- query := r.URL.Query()
-
- respondWithBooks(demoUserID, query, w)
-}
-
// GetBooks returns books for the user
func (a *App) GetBooks(w http.ResponseWriter, r *http.Request) {
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
@@ -154,7 +138,7 @@ func (a *App) GetBooks(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
- respondWithBooks(user.ID, query, w)
+ respondWithBooks(a.DB, user.ID, query, w)
}
// GetBook returns a book for the user
@@ -164,13 +148,11 @@ func (a *App) GetBook(w http.ResponseWriter, r *http.Request) {
return
}
- db := database.DBConn
-
vars := mux.Vars(r)
bookUUID := vars["bookUUID"]
var book database.Book
- conn := db.Where("uuid = ? AND user_id = ?", bookUUID, user.ID).First(&book)
+ conn := a.DB.Where("uuid = ? AND user_id = ?", bookUUID, user.ID).First(&book)
if conn.RecordNotFound() {
w.WriteHeader(http.StatusNotFound)
@@ -204,8 +186,7 @@ func (a *App) UpdateBook(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
uuid := vars["bookUUID"]
- db := database.DBConn
- tx := db.Begin()
+ tx := a.DB.Begin()
var book database.Book
if err := tx.Where("user_id = ? AND uuid = ?", user.ID, uuid).First(&book).Error; err != nil {
@@ -250,8 +231,7 @@ func (a *App) DeleteBook(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
uuid := vars["bookUUID"]
- db := database.DBConn
- tx := db.Begin()
+ tx := a.DB.Begin()
var book database.Book
if err := tx.Where("user_id = ? AND uuid = ?", user.ID, uuid).First(&book).Error; err != nil {
diff --git a/pkg/server/handlers/v3_books_test.go b/pkg/server/handlers/v3_books_test.go
index ab07c7a6..375275fc 100644
--- a/pkg/server/handlers/v3_books_test.go
+++ b/pkg/server/handlers/v3_books_test.go
@@ -26,23 +26,20 @@ import (
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
- "github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/database"
+ "github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
)
-func init() {
- testutils.InitTestDB()
-}
-
func TestGetBooks(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
- Clock: clock.NewMock(),
+
+Clock: clock.NewMock(),
})
defer server.Close()
@@ -55,28 +52,28 @@ func TestGetBooks(t *testing.T) {
USN: 1123,
Deleted: false,
}
- testutils.MustExec(t, db.Save(&b1), "preparing b1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
b2 := database.Book{
UserID: user.ID,
Label: "css",
USN: 1125,
Deleted: false,
}
- testutils.MustExec(t, db.Save(&b2), "preparing b2")
+ testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
b3 := database.Book{
UserID: anotherUser.ID,
Label: "css",
USN: 1128,
Deleted: false,
}
- testutils.MustExec(t, db.Save(&b3), "preparing b3")
+ testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
b4 := database.Book{
UserID: user.ID,
Label: "",
USN: 1129,
Deleted: true,
}
- testutils.MustExec(t, db.Save(&b4), "preparing b4")
+ testutils.MustExec(t, testutils.DB.Save(&b4), "preparing b4")
// Execute
req := testutils.MakeReq(server, "GET", "/v3/books", "")
@@ -91,9 +88,9 @@ func TestGetBooks(t *testing.T) {
}
var b1Record, b2Record database.Book
- testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
- testutils.MustExec(t, db.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
- testutils.MustExec(t, db.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
expected := []presenters.Book{
{
@@ -116,12 +113,13 @@ func TestGetBooks(t *testing.T) {
}
func TestGetBooksByName(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
- Clock: clock.NewMock(),
+
+Clock: clock.NewMock(),
})
defer server.Close()
@@ -133,17 +131,17 @@ func TestGetBooksByName(t *testing.T) {
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, db.Save(&b1), "preparing b1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
b2 := database.Book{
UserID: user.ID,
Label: "css",
}
- testutils.MustExec(t, db.Save(&b2), "preparing b2")
+ testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
b3 := database.Book{
UserID: anotherUser.ID,
Label: "js",
}
- testutils.MustExec(t, db.Save(&b3), "preparing b3")
+ testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
// Execute
res := testutils.HTTPAuthDo(t, req, user)
@@ -157,7 +155,7 @@ func TestGetBooksByName(t *testing.T) {
}
var b1Record database.Book
- testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
expected := []presenters.Book{
{
@@ -201,39 +199,40 @@ func TestDeleteBook(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("originally deleted %t", tc.deleted), func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
- Clock: clock.NewMock(),
+
+Clock: clock.NewMock(),
})
defer server.Close()
user := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&user).Update("max_usn", 58), "preparing user max_usn")
+ testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 58), "preparing user max_usn")
anotherUser := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 109), "preparing another user max_usn")
+ testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 109), "preparing another user max_usn")
b1 := database.Book{
UserID: user.ID,
Label: "js",
USN: 1,
}
- testutils.MustExec(t, db.Save(&b1), "preparing a book data")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing a book data")
b2 := database.Book{
UserID: user.ID,
Label: tc.label,
USN: 2,
Deleted: tc.deleted,
}
- testutils.MustExec(t, db.Save(&b2), "preparing a book data")
+ testutils.MustExec(t, testutils.DB.Save(&b2), "preparing a book data")
b3 := database.Book{
UserID: anotherUser.ID,
Label: "linux",
USN: 3,
}
- testutils.MustExec(t, db.Save(&b3), "preparing a book data")
+ testutils.MustExec(t, testutils.DB.Save(&b3), "preparing a book data")
var n2Body string
if !tc.deleted {
@@ -250,7 +249,7 @@ func TestDeleteBook(t *testing.T) {
Body: "n1 content",
USN: 4,
}
- testutils.MustExec(t, db.Save(&n1), "preparing a note data")
+ testutils.MustExec(t, testutils.DB.Save(&n1), "preparing a note data")
n2 := database.Note{
UserID: user.ID,
BookUUID: b2.UUID,
@@ -258,7 +257,7 @@ func TestDeleteBook(t *testing.T) {
USN: 5,
Deleted: tc.deleted,
}
- testutils.MustExec(t, db.Save(&n2), "preparing a note data")
+ testutils.MustExec(t, testutils.DB.Save(&n2), "preparing a note data")
n3 := database.Note{
UserID: user.ID,
BookUUID: b2.UUID,
@@ -266,7 +265,7 @@ func TestDeleteBook(t *testing.T) {
USN: 6,
Deleted: tc.deleted,
}
- testutils.MustExec(t, db.Save(&n3), "preparing a note data")
+ testutils.MustExec(t, testutils.DB.Save(&n3), "preparing a note data")
n4 := database.Note{
UserID: user.ID,
BookUUID: b2.UUID,
@@ -274,14 +273,14 @@ func TestDeleteBook(t *testing.T) {
USN: 7,
Deleted: true,
}
- testutils.MustExec(t, db.Save(&n4), "preparing a note data")
+ testutils.MustExec(t, testutils.DB.Save(&n4), "preparing a note data")
n5 := database.Note{
UserID: anotherUser.ID,
BookUUID: b3.UUID,
Body: "n5 content",
USN: 8,
}
- testutils.MustExec(t, db.Save(&n5), "preparing a note data")
+ testutils.MustExec(t, testutils.DB.Save(&n5), "preparing a note data")
endpoint := fmt.Sprintf("/v3/books/%s", b2.UUID)
req := testutils.MakeReq(server, "DELETE", endpoint, "")
@@ -299,17 +298,17 @@ func TestDeleteBook(t *testing.T) {
var userRecord database.User
var bookCount, noteCount int
- testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
- testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
- testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
- testutils.MustExec(t, db.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
- testutils.MustExec(t, db.Where("id = ?", b3.ID).First(&b3Record), "finding b3")
- testutils.MustExec(t, db.Where("id = ?", n1.ID).First(&n1Record), "finding n1")
- testutils.MustExec(t, db.Where("id = ?", n2.ID).First(&n2Record), "finding n2")
- testutils.MustExec(t, db.Where("id = ?", n3.ID).First(&n3Record), "finding n3")
- testutils.MustExec(t, db.Where("id = ?", n4.ID).First(&n4Record), "finding n4")
- testutils.MustExec(t, db.Where("id = ?", n5.ID).First(&n5Record), "finding n5")
- testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
+ testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
+ testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", b3.ID).First(&b3Record), "finding b3")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", n1.ID).First(&n1Record), "finding n1")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", n2.ID).First(&n2Record), "finding n2")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", n3.ID).First(&n3Record), "finding n3")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", n4.ID).First(&n4Record), "finding n4")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", n5.ID).First(&n5Record), "finding n5")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equal(t, bookCount, 3, "book count mismatch")
assert.Equal(t, noteCount, 5, "note count mismatch")
@@ -351,17 +350,18 @@ func TestDeleteBook(t *testing.T) {
}
func TestCreateBook(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
- Clock: clock.NewMock(),
+
+Clock: clock.NewMock(),
})
defer server.Close()
user := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
+ testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
req := testutils.MakeReq(server, "POST", "/v3/books", `{"name": "js"}`)
req.Header.Set("Version", "0.1.1")
@@ -376,10 +376,10 @@ func TestCreateBook(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
var bookCount, noteCount int
- testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
- testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
- testutils.MustExec(t, db.First(&bookRecord), "finding book")
- testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
+ testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
+ testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
+ testutils.MustExec(t, testutils.DB.First(&bookRecord), "finding book")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
maxUSN := 102
@@ -410,24 +410,25 @@ func TestCreateBook(t *testing.T) {
}
func TestCreateBookDuplicate(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
- Clock: clock.NewMock(),
+
+Clock: clock.NewMock(),
})
defer server.Close()
user := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
+ testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
UserID: user.ID,
Label: "js",
USN: 58,
}
- testutils.MustExec(t, db.Save(&b1), "preparing book data")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing book data")
// Execute
req := testutils.MakeReq(server, "POST", "/v3/books", `{"name": "js"}`)
@@ -439,10 +440,10 @@ func TestCreateBookDuplicate(t *testing.T) {
var bookRecord database.Book
var bookCount, noteCount int
var userRecord database.User
- testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
- testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
- testutils.MustExec(t, db.First(&bookRecord), "finding book")
- testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
+ testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
+ testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
+ testutils.MustExec(t, testutils.DB.First(&bookRecord), "finding book")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, 1, "book count mismatch")
assert.Equalf(t, noteCount, 0, "note count mismatch")
@@ -489,17 +490,18 @@ func TestUpdateBook(t *testing.T) {
for idx, tc := range testCases {
func() {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
- Clock: clock.NewMock(),
+
+Clock: clock.NewMock(),
})
defer server.Close()
user := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
+ testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
UUID: tc.bookUUID,
@@ -507,15 +509,15 @@ func TestUpdateBook(t *testing.T) {
Label: tc.bookLabel,
Deleted: tc.bookDeleted,
}
- testutils.MustExec(t, db.Save(&b1), "preparing b1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
b2 := database.Book{
UUID: b2UUID,
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, db.Save(&b2), "preparing b2")
+ testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
- // Execute
+ // Executdb,e
endpoint := fmt.Sprintf("/v3/books/%s", tc.bookUUID)
req := testutils.MakeReq(server, "PATCH", endpoint, tc.payload)
res := testutils.HTTPAuthDo(t, req, user)
@@ -526,10 +528,10 @@ func TestUpdateBook(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
var noteCount, bookCount int
- testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
- testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
- testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
- testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
+ testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
+ testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, 2, "book count mismatch")
assert.Equalf(t, noteCount, 0, "note count mismatch")
diff --git a/pkg/server/handlers/v3_notes.go b/pkg/server/handlers/v3_notes.go
index e6e62c78..65dc9f17 100644
--- a/pkg/server/handlers/v3_notes.go
+++ b/pkg/server/handlers/v3_notes.go
@@ -23,10 +23,10 @@ import (
"fmt"
"net/http"
+ "github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/operations"
"github.com/dnote/dnote/pkg/server/presenters"
- "github.com/dnote/dnote/pkg/server/database"
"github.com/gorilla/mux"
"github.com/pkg/errors"
)
@@ -48,7 +48,6 @@ func validateUpdateNotePayload(p updateNotePayload) bool {
// UpdateNote updates note
func (a *App) UpdateNote(w http.ResponseWriter, r *http.Request) {
- db := database.DBConn
vars := mux.Vars(r)
noteUUID := vars["noteUUID"]
@@ -71,12 +70,12 @@ func (a *App) UpdateNote(w http.ResponseWriter, r *http.Request) {
}
var note database.Note
- if err := db.Where("uuid = ? AND user_id = ?", noteUUID, user.ID).First(¬e).Error; err != nil {
+ if err := a.DB.Where("uuid = ? AND user_id = ?", noteUUID, user.ID).First(¬e).Error; err != nil {
HandleError(w, "finding note", err, http.StatusInternalServerError)
return
}
- tx := db.Begin()
+ tx := a.DB.Begin()
note, err = operations.UpdateNote(tx, user, a.Clock, note, &operations.UpdateNoteParams{
BookUUID: params.BookUUID,
@@ -116,8 +115,6 @@ type deleteNoteResp struct {
// DeleteNote removes note
func (a *App) DeleteNote(w http.ResponseWriter, r *http.Request) {
- db := database.DBConn
-
vars := mux.Vars(r)
noteUUID := vars["noteUUID"]
@@ -128,12 +125,12 @@ func (a *App) DeleteNote(w http.ResponseWriter, r *http.Request) {
}
var note database.Note
- if err := db.Where("uuid = ? AND user_id = ?", noteUUID, user.ID).Preload("Book").First(¬e).Error; err != nil {
+ if err := a.DB.Where("uuid = ? AND user_id = ?", noteUUID, user.ID).Preload("Book").First(¬e).Error; err != nil {
HandleError(w, "finding note", err, http.StatusInternalServerError)
return
}
- tx := db.Begin()
+ tx := a.DB.Begin()
n, err := operations.DeleteNote(tx, user, note)
if err != nil {
@@ -193,13 +190,12 @@ func (a *App) CreateNote(w http.ResponseWriter, r *http.Request) {
}
var book database.Book
- db := database.DBConn
- if err := db.Where("uuid = ? AND user_id = ?", params.BookUUID, user.ID).First(&book).Error; err != nil {
+ if err := a.DB.Where("uuid = ? AND user_id = ?", params.BookUUID, user.ID).First(&book).Error; err != nil {
HandleError(w, "finding book", err, http.StatusInternalServerError)
return
}
- note, err := operations.CreateNote(user, a.Clock, params.BookUUID, params.Content, params.AddedOn, params.EditedOn, false)
+ note, err := operations.CreateNote(a.DB, user, a.Clock, params.BookUUID, params.Content, params.AddedOn, params.EditedOn, false)
if err != nil {
HandleError(w, "creating note", err, http.StatusInternalServerError)
return
diff --git a/pkg/server/handlers/v3_notes_test.go b/pkg/server/handlers/v3_notes_test.go
index 9b2f221b..524018e1 100644
--- a/pkg/server/handlers/v3_notes_test.go
+++ b/pkg/server/handlers/v3_notes_test.go
@@ -29,29 +29,26 @@ import (
"github.com/dnote/dnote/pkg/server/testutils"
)
-func init() {
- testutils.InitTestDB()
-}
-
func TestCreateNote(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
- Clock: clock.NewMock(),
+
+Clock: clock.NewMock(),
})
defer server.Close()
user := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
+ testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
UserID: user.ID,
Label: "js",
USN: 58,
}
- testutils.MustExec(t, db.Save(&b1), "preparing b1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
// Execute
dat := fmt.Sprintf(`{"book_uuid": "%s", "content": "note content"}`, b1.UUID)
@@ -65,11 +62,11 @@ func TestCreateNote(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
var bookCount, noteCount int
- testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
- testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
- testutils.MustExec(t, db.First(¬eRecord), "finding note")
- testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
- testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
+ testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
+ testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
+ testutils.MustExec(t, testutils.DB.First(¬eRecord), "finding note")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, 1, "book count mismatch")
assert.Equalf(t, noteCount, 1, "note count mismatch")
@@ -238,30 +235,31 @@ func TestUpdateNote(t *testing.T) {
for idx, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
- Clock: clock.NewMock(),
+
+Clock: clock.NewMock(),
})
defer server.Close()
user := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
+ testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
UUID: b1UUID,
UserID: user.ID,
Label: "css",
}
- testutils.MustExec(t, db.Save(&b1), "preparing b1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
b2 := database.Book{
UUID: b2UUID,
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, db.Save(&b2), "preparing b2")
+ testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
note := database.Note{
UserID: user.ID,
@@ -271,7 +269,7 @@ func TestUpdateNote(t *testing.T) {
Deleted: tc.noteDeleted,
Public: tc.notePublic,
}
- testutils.MustExec(t, db.Save(¬e), "preparing note")
+ testutils.MustExec(t, testutils.DB.Save(¬e), "preparing note")
// Execute
endpoint := fmt.Sprintf("/v3/notes/%s", note.UUID)
@@ -285,11 +283,11 @@ func TestUpdateNote(t *testing.T) {
var noteRecord database.Note
var userRecord database.User
var noteCount, bookCount int
- testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
- testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
- testutils.MustExec(t, db.Where("uuid = ?", note.UUID).First(¬eRecord), "finding note")
- testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
- testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
+ testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
+ testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", note.UUID).First(¬eRecord), "finding note")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, 2, "book count mismatch")
assert.Equalf(t, noteCount, 1, "note count mismatch")
@@ -333,24 +331,25 @@ func TestDeleteNote(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("originally deleted %t", tc.deleted), func(t *testing.T) {
+
defer testutils.ClearData()
- db := database.DBConn
// Setup
server := MustNewServer(t, &App{
- Clock: clock.NewMock(),
+
+Clock: clock.NewMock(),
})
defer server.Close()
user := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&user).Update("max_usn", 981), "preparing user max_usn")
+ testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 981), "preparing user max_usn")
b1 := database.Book{
UUID: b1UUID,
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, db.Save(&b1), "preparing b1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
note := database.Note{
UserID: user.ID,
BookUUID: b1.UUID,
@@ -358,7 +357,7 @@ func TestDeleteNote(t *testing.T) {
Deleted: tc.deleted,
USN: tc.originalUSN,
}
- testutils.MustExec(t, db.Save(¬e), "preparing note")
+ testutils.MustExec(t, testutils.DB.Save(¬e), "preparing note")
// Execute
endpoint := fmt.Sprintf("/v3/notes/%s", note.UUID)
@@ -372,11 +371,11 @@ func TestDeleteNote(t *testing.T) {
var noteRecord database.Note
var userRecord database.User
var bookCount, noteCount int
- testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
- testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
- testutils.MustExec(t, db.Where("uuid = ?", note.UUID).First(¬eRecord), "finding note")
- testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
- testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
+ testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
+ testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", note.UUID).First(¬eRecord), "finding note")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, 1, "book count mismatch")
assert.Equalf(t, noteCount, 1, "note count mismatch")
diff --git a/pkg/server/handlers/v3_sync.go b/pkg/server/handlers/v3_sync.go
index 0994fa94..71468cc4 100644
--- a/pkg/server/handlers/v3_sync.go
+++ b/pkg/server/handlers/v3_sync.go
@@ -26,8 +26,8 @@ import (
"strconv"
"time"
- "github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/database"
+ "github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/log"
"github.com/pkg/errors"
)
@@ -121,14 +121,12 @@ func (e *queryParamError) Error() string {
}
func (a *App) newFragment(userID, userMaxUSN, afterUSN, limit int) (SyncFragment, error) {
- db := database.DBConn
-
var notes []database.Note
- if err := db.Where("user_id = ? AND usn > ? AND usn <= ?", userID, afterUSN, userMaxUSN).Order("usn ASC").Limit(limit).Find(¬es).Error; err != nil {
+ if err := a.DB.Where("user_id = ? AND usn > ? AND usn <= ?", userID, afterUSN, userMaxUSN).Order("usn ASC").Limit(limit).Find(¬es).Error; err != nil {
return SyncFragment{}, nil
}
var books []database.Book
- if err := db.Where("user_id = ? AND usn > ? AND usn <= ?", userID, afterUSN, userMaxUSN).Order("usn ASC").Limit(limit).Find(&books).Error; err != nil {
+ if err := a.DB.Where("user_id = ? AND usn > ? AND usn <= ?", userID, afterUSN, userMaxUSN).Order("usn ASC").Limit(limit).Find(&books).Error; err != nil {
return SyncFragment{}, nil
}
diff --git a/pkg/server/helpers/helpers.go b/pkg/server/helpers/helpers.go
index 93a47355..56615f8b 100644
--- a/pkg/server/helpers/helpers.go
+++ b/pkg/server/helpers/helpers.go
@@ -19,8 +19,8 @@
package helpers
import (
- "github.com/dnote/dnote/pkg/server/database"
"github.com/google/uuid"
+ "github.com/jinzhu/gorm"
"github.com/pkg/errors"
)
@@ -29,8 +29,7 @@ const (
)
// GetDemoUserID returns ID of the demo user
-func GetDemoUserID() (int, error) {
- db := database.DBConn
+func GetDemoUserID(db *gorm.DB) (int, error) {
result := struct {
UserID int
diff --git a/pkg/server/job/job.go b/pkg/server/job/job.go
index 5be2867e..dec4fa2c 100644
--- a/pkg/server/job/job.go
+++ b/pkg/server/job/job.go
@@ -24,6 +24,7 @@ import (
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/job/repetition"
+ "github.com/jinzhu/gorm"
"github.com/pkg/errors"
"github.com/robfig/cron"
)
@@ -45,12 +46,12 @@ func checkEnvironment() error {
return nil
}
-func schedule(ch chan error) {
+func schedule(db *gorm.DB, ch chan error) {
cl := clock.New()
// Schedule jobs
c := cron.New()
- scheduleJob(c, "* * * * *", func() { repetition.Do(cl) })
+ scheduleJob(c, "* * * * *", func() { repetition.Do(db, cl) })
c.Start()
ch <- nil
@@ -60,13 +61,13 @@ func schedule(ch chan error) {
}
// Run starts the background tasks in a separate goroutine that runs forever
-func Run() error {
+func Run(db *gorm.DB) error {
if err := checkEnvironment(); err != nil {
return errors.Wrap(err, "checking environment variables")
}
ch := make(chan error)
- go schedule(ch)
+ go schedule(db, ch)
if err := <-ch; err != nil {
return errors.Wrap(err, "scheduling jobs")
}
diff --git a/pkg/server/job/repetition/main_test.go b/pkg/server/job/repetition/main_test.go
new file mode 100644
index 00000000..6555ee67
--- /dev/null
+++ b/pkg/server/job/repetition/main_test.go
@@ -0,0 +1,17 @@
+package repetition
+
+import (
+ "os"
+ "testing"
+
+ "github.com/dnote/dnote/pkg/server/testutils"
+)
+
+func TestMain(m *testing.M) {
+ testutils.InitTestDB()
+
+ code := m.Run()
+ testutils.ClearData()
+
+ os.Exit(code)
+}
diff --git a/pkg/server/job/repetition/repetition.go b/pkg/server/job/repetition/repetition.go
index 6c4df721..621f59e4 100644
--- a/pkg/server/job/repetition/repetition.go
+++ b/pkg/server/job/repetition/repetition.go
@@ -31,20 +31,29 @@ import (
"github.com/pkg/errors"
)
+// BuildEmailParams is the params for building an email
+type BuildEmailParams struct {
+ Now time.Time
+ User database.User
+ EmailAddr string
+ Digest database.Digest
+ Rule database.RepetitionRule
+}
+
// BuildEmail builds an email for the spaced repetition
-func BuildEmail(now time.Time, user database.User, emailAddr string, digest database.Digest, rule database.RepetitionRule) (*mailer.Email, error) {
- date := now.Format("Jan 02 2006")
- subject := fmt.Sprintf("%s %s", rule.Title, date)
- tok, err := mailer.GetToken(user, database.TokenTypeRepetition)
+func BuildEmail(db *gorm.DB, p BuildEmailParams) (*mailer.Email, error) {
+ date := p.Now.Format("Jan 02 2006")
+ subject := fmt.Sprintf("%s %s", p.Rule.Title, date)
+ tok, err := mailer.GetToken(db, p.User, database.TokenTypeRepetition)
if err != nil {
return nil, errors.Wrap(err, "getting email frequency token")
}
- t1 := now.AddDate(0, 0, -3).UnixNano()
- t2 := now.AddDate(0, 0, -7).UnixNano()
+ t1 := p.Now.AddDate(0, 0, -3).UnixNano()
+ t2 := p.Now.AddDate(0, 0, -7).UnixNano()
noteInfos := []mailer.DigestNoteInfo{}
- for _, note := range digest.Notes {
+ for _, note := range p.Digest.Notes {
var stage int
if note.AddedOn > t1 {
stage = 1
@@ -60,7 +69,7 @@ func BuildEmail(now time.Time, user database.User, emailAddr string, digest data
bookCount := 0
bookMap := map[string]bool{}
- for _, n := range digest.Notes {
+ for _, n := range p.Digest.Notes {
if ok := bookMap[n.Book.Label]; !ok {
bookCount++
bookMap[n.Book.Label] = true
@@ -71,14 +80,14 @@ func BuildEmail(now time.Time, user database.User, emailAddr string, digest data
Subject: subject,
NoteInfo: noteInfos,
ActiveBookCount: bookCount,
- ActiveNoteCount: len(digest.Notes),
+ ActiveNoteCount: len(p.Digest.Notes),
EmailSessionToken: tok.Value,
- RuleUUID: rule.UUID,
- RuleTitle: rule.Title,
+ RuleUUID: p.Rule.UUID,
+ RuleTitle: p.Rule.Title,
WebURL: os.Getenv("WebURL"),
}
- email := mailer.NewEmail("noreply@getdnote.com", []string{emailAddr}, subject)
+ email := mailer.NewEmail("noreply@getdnote.com", []string{p.EmailAddr}, subject)
if err := email.ParseTemplate(mailer.EmailTypeWeeklyDigest, tmplData); err != nil {
return nil, err
}
@@ -86,12 +95,11 @@ func BuildEmail(now time.Time, user database.User, emailAddr string, digest data
return email, nil
}
-func getEligibleRules(now time.Time) ([]database.RepetitionRule, error) {
+func getEligibleRules(db *gorm.DB, now time.Time) ([]database.RepetitionRule, error) {
hour := now.Hour()
minute := now.Minute()
var ret []database.RepetitionRule
- db := database.DBConn
if err := db.
Where("users.cloud AND repetition_rules.hour = ? AND repetition_rules.minute = ? AND repetition_rules.enabled", hour, minute).
Joins("INNER JOIN users ON users.id = repetition_rules.user_id").
@@ -120,9 +128,7 @@ func build(tx *gorm.DB, rule database.RepetitionRule) (database.Digest, error) {
return digest, nil
}
-func notify(now time.Time, user database.User, digest database.Digest, rule database.RepetitionRule) error {
- db := database.DBConn
-
+func notify(db *gorm.DB, now time.Time, user database.User, digest database.Digest, rule database.RepetitionRule) error {
var account database.Account
if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
return errors.Wrap(err, "getting account")
@@ -135,7 +141,13 @@ func notify(now time.Time, user database.User, digest database.Digest, rule data
return nil
}
- email, err := BuildEmail(now, user, account.Email.String, digest, rule)
+ email, err := BuildEmail(db, BuildEmailParams{
+ Now: now,
+ User: user,
+ EmailAddr: account.Email.String,
+ Digest: digest,
+ Rule: rule,
+ })
if err != nil {
return errors.Wrap(err, "making email")
}
@@ -185,12 +197,11 @@ func touchTimestamp(tx *gorm.DB, rule database.RepetitionRule, now time.Time) er
return nil
}
-func process(now time.Time, rule database.RepetitionRule) error {
+func process(db *gorm.DB, now time.Time, rule database.RepetitionRule) error {
log.WithFields(log.Fields{
"uuid": rule.UUID,
}).Info("processing repetition")
- db := database.DBConn
tx := db.Begin()
if !checkCooldown(now, rule) {
@@ -224,7 +235,7 @@ func process(now time.Time, rule database.RepetitionRule) error {
return errors.Wrap(err, "committing transaction")
}
- if err := notify(now, user, digest, rule); err != nil {
+ if err := notify(db, now, user, digest, rule); err != nil {
return errors.Wrap(err, "notifying user")
}
@@ -236,10 +247,10 @@ func process(now time.Time, rule database.RepetitionRule) error {
}
// Do creates spaced repetitions and delivers the results based on the rules
-func Do(c clock.Clock) error {
+func Do(db *gorm.DB, c clock.Clock) error {
now := c.Now().UTC()
- rules, err := getEligibleRules(now)
+ rules, err := getEligibleRules(db, now)
if err != nil {
return errors.Wrap(err, "getting eligible repetition rules")
}
@@ -251,7 +262,7 @@ func Do(c clock.Clock) error {
}).Info("processing rules")
for _, rule := range rules {
- if err := process(now, rule); err != nil {
+ if err := process(db, now, rule); err != nil {
log.WithFields(log.Fields{
"rule uuid": rule.UUID,
}).ErrorWrap(err, "Could not process the repetition rule")
diff --git a/pkg/server/job/repetition/repetition_test.go b/pkg/server/job/repetition/repetition_test.go
index 2c6d705a..ba02ee1b 100644
--- a/pkg/server/job/repetition/repetition_test.go
+++ b/pkg/server/job/repetition/repetition_test.go
@@ -29,24 +29,18 @@ import (
"github.com/dnote/dnote/pkg/server/testutils"
)
-func init() {
- testutils.InitTestDB()
-}
-
func assertLastActive(t *testing.T, ruleUUID string, lastActive int64) {
- db := database.DBConn
var rule database.RepetitionRule
- testutils.MustExec(t, db.Where("uuid = ?", ruleUUID).First(&rule), "finding rule1")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", ruleUUID).First(&rule), "finding rule1")
assert.Equal(t, rule.LastActive, lastActive, "LastActive mismatch")
}
func assertDigestCount(t *testing.T, rule database.RepetitionRule, expected int) {
- db := database.DBConn
var digestCount int
- testutils.MustExec(t, db.Model(&database.Digest{}).Where("rule_id = ? AND user_id = ?", rule.ID, rule.UserID).Count(&digestCount), "counting digest")
+ testutils.MustExec(t, testutils.DB.Model(&database.Digest{}).Where("rule_id = ? AND user_id = ?", rule.ID, rule.UserID).Count(&digestCount), "counting digest")
assert.Equal(t, digestCount, expected, "digest count mismatch")
}
@@ -74,68 +68,67 @@ func TestDo(t *testing.T) {
},
}
- db := database.DBConn
- testutils.MustExec(t, db.Save(&r1), "preparing rule1")
+ testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
c := clock.NewMock()
// Test
// 1 day later
c.SetNow(time.Date(2009, time.November, 2, 12, 2, 1, 0, time.UTC))
- Do(c)
+ Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(0))
assertDigestCount(t, r1, 0)
// 2 days later
c.SetNow(time.Date(2009, time.November, 3, 12, 2, 1, 0, time.UTC))
- Do(c)
+ Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(0))
assertDigestCount(t, r1, 0)
// 3 days later - should be processed
c.SetNow(time.Date(2009, time.November, 4, 12, 1, 1, 0, time.UTC))
- Do(c)
+ Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(0))
assertDigestCount(t, r1, 0)
c.SetNow(time.Date(2009, time.November, 4, 12, 2, 1, 0, time.UTC))
- Do(c)
+ Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(1257336120000))
assertDigestCount(t, r1, 1)
c.SetNow(time.Date(2009, time.November, 4, 12, 3, 1, 0, time.UTC))
- Do(c)
+ Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(1257336120000))
assertDigestCount(t, r1, 1)
// 4 day later
c.SetNow(time.Date(2009, time.November, 5, 12, 2, 1, 0, time.UTC))
- Do(c)
+ Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(1257336120000))
assertDigestCount(t, r1, 1)
// 5 days later
c.SetNow(time.Date(2009, time.November, 6, 12, 2, 1, 0, time.UTC))
- Do(c)
+ Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(1257336120000))
assertDigestCount(t, r1, 1)
// 6 days later - should be processed
c.SetNow(time.Date(2009, time.November, 7, 12, 2, 1, 0, time.UTC))
- Do(c)
+ Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(1257595320000))
assertDigestCount(t, r1, 2)
// 7 days later
c.SetNow(time.Date(2009, time.November, 8, 12, 2, 1, 0, time.UTC))
- Do(c)
+ Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(1257595320000))
assertDigestCount(t, r1, 2)
// 8 days later
c.SetNow(time.Date(2009, time.November, 9, 12, 2, 1, 0, time.UTC))
- Do(c)
+ Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(1257595320000))
assertDigestCount(t, r1, 2)
// 9 days later - should be processed
c.SetNow(time.Date(2009, time.November, 10, 12, 2, 1, 0, time.UTC))
- Do(c)
+ Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(1257854520000))
assertDigestCount(t, r1, 3)
})
@@ -177,15 +170,14 @@ func TestDo(t *testing.T) {
},
}
- db := database.DBConn
- testutils.MustExec(t, db.Save(&r1), "preparing rule1")
+ testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
c := clock.NewMock()
c.SetNow(time.Date(2009, time.November, 10, 12, 2, 1, 0, time.UTC))
- Do(c)
+ Do(testutils.DB, c)
var rule database.RepetitionRule
- testutils.MustExec(t, db.Where("uuid = ?", r1.UUID).First(&rule), "finding rule1")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", r1.UUID).First(&rule), "finding rule1")
assert.Equal(t, rule.LastActive, time.Date(2009, time.November, 10, 12, 2, 0, 0, time.UTC).UnixNano()/int64(time.Millisecond), "LastActive mismsatch")
assert.Equal(t, rule.NextActive, time.Date(2009, time.November, 13, 12, 2, 0, 0, time.UTC).UnixNano()/int64(time.Millisecond), "NextActive mismsatch")
@@ -216,13 +208,12 @@ func TestDo_Disabled(t *testing.T) {
},
}
- db := database.DBConn
- testutils.MustExec(t, db.Save(&r1), "preparing rule1")
+ testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
// Execute
c := clock.NewMock()
c.SetNow(time.Date(2009, time.November, 4, 12, 2, 0, 0, time.UTC))
- Do(c)
+ Do(testutils.DB, c)
// Test
assertLastActive(t, r1.UUID, int64(0))
@@ -241,40 +232,39 @@ func TestDo_BalancedStrategy(t *testing.T) {
}
setup := func() testData {
- db := database.DBConn
user := testutils.SetupUserData()
b1 := database.Book{
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, db.Save(&b1), "preparing b1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
b2 := database.Book{
UserID: user.ID,
Label: "css",
}
- testutils.MustExec(t, db.Save(&b2), "preparing b2")
+ testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
b3 := database.Book{
UserID: user.ID,
Label: "golang",
}
- testutils.MustExec(t, db.Save(&b3), "preparing b3")
+ testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
n1 := database.Note{
UserID: user.ID,
BookUUID: b1.UUID,
}
- testutils.MustExec(t, db.Save(&n1), "preparing n1")
+ testutils.MustExec(t, testutils.DB.Save(&n1), "preparing n1")
n2 := database.Note{
UserID: user.ID,
BookUUID: b2.UUID,
}
- testutils.MustExec(t, db.Save(&n2), "preparing n2")
+ testutils.MustExec(t, testutils.DB.Save(&n2), "preparing n2")
n3 := database.Note{
UserID: user.ID,
BookUUID: b3.UUID,
}
- testutils.MustExec(t, db.Save(&n3), "preparing n3")
+ testutils.MustExec(t, testutils.DB.Save(&n3), "preparing n3")
return testData{
User: user,
@@ -293,7 +283,6 @@ func TestDo_BalancedStrategy(t *testing.T) {
// Set up
dat := setup()
- db := database.DBConn
t0 := time.Date(2009, time.November, 1, 12, 0, 0, 0, time.UTC)
t1 := time.Date(2009, time.November, 8, 21, 0, 0, 0, time.UTC)
r1 := database.RepetitionRule{
@@ -312,20 +301,20 @@ func TestDo_BalancedStrategy(t *testing.T) {
UpdatedAt: t0,
},
}
- testutils.MustExec(t, db.Save(&r1), "preparing rule1")
+ testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
// Execute
c := clock.NewMock()
c.SetNow(time.Date(2009, time.November, 8, 21, 0, 0, 0, time.UTC))
- Do(c)
+ Do(testutils.DB, c)
// Test
assertLastActive(t, r1.UUID, int64(1257714000000))
assertDigestCount(t, r1, 1)
var repetition database.Digest
- testutils.MustExec(t, db.Where("rule_id = ? AND user_id = ?", r1.ID, r1.UserID).Preload("Notes").First(&repetition), "finding repetition")
+ testutils.MustExec(t, testutils.DB.Where("rule_id = ? AND user_id = ?", r1.ID, r1.UserID).Preload("Notes").First(&repetition), "finding repetition")
sort.SliceStable(repetition.Notes, func(i, j int) bool {
n1 := repetition.Notes[i]
@@ -335,9 +324,9 @@ func TestDo_BalancedStrategy(t *testing.T) {
})
var n1Record, n2Record, n3Record database.Note
- testutils.MustExec(t, db.Where("uuid = ?", dat.Note1.UUID).First(&n1Record), "finding n1")
- testutils.MustExec(t, db.Where("uuid = ?", dat.Note2.UUID).First(&n2Record), "finding n2")
- testutils.MustExec(t, db.Where("uuid = ?", dat.Note3.UUID).First(&n3Record), "finding n3")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note1.UUID).First(&n1Record), "finding n1")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note2.UUID).First(&n2Record), "finding n2")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note3.UUID).First(&n3Record), "finding n3")
expected := []database.Note{n1Record, n2Record, n3Record}
assert.DeepEqual(t, repetition.Notes, expected, "result mismatch")
})
@@ -348,7 +337,6 @@ func TestDo_BalancedStrategy(t *testing.T) {
// Set up
dat := setup()
- db := database.DBConn
t0 := time.Date(2009, time.November, 1, 12, 0, 0, 0, time.UTC)
t1 := time.Date(2009, time.November, 8, 21, 0, 0, 0, time.UTC)
r1 := database.RepetitionRule{
@@ -368,20 +356,20 @@ func TestDo_BalancedStrategy(t *testing.T) {
UpdatedAt: t0,
},
}
- testutils.MustExec(t, db.Save(&r1), "preparing rule1")
+ testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
// Execute
c := clock.NewMock()
c.SetNow(time.Date(2009, time.November, 8, 21, 0, 1, 0, time.UTC))
- Do(c)
+ Do(testutils.DB, c)
// Test
assertLastActive(t, r1.UUID, int64(1257714000000))
assertDigestCount(t, r1, 1)
var repetition database.Digest
- testutils.MustExec(t, db.Where("rule_id = ? AND user_id = ?", r1.ID, r1.UserID).Preload("Notes").First(&repetition), "finding repetition")
+ testutils.MustExec(t, testutils.DB.Where("rule_id = ? AND user_id = ?", r1.ID, r1.UserID).Preload("Notes").First(&repetition), "finding repetition")
sort.SliceStable(repetition.Notes, func(i, j int) bool {
n1 := repetition.Notes[i]
@@ -391,8 +379,8 @@ func TestDo_BalancedStrategy(t *testing.T) {
})
var n2Record, n3Record database.Note
- testutils.MustExec(t, db.Where("uuid = ?", dat.Note2.UUID).First(&n2Record), "finding n2")
- testutils.MustExec(t, db.Where("uuid = ?", dat.Note3.UUID).First(&n3Record), "finding n3")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note2.UUID).First(&n2Record), "finding n2")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note3.UUID).First(&n3Record), "finding n3")
expected := []database.Note{n2Record, n3Record}
assert.DeepEqual(t, repetition.Notes, expected, "result mismatch")
})
@@ -403,7 +391,6 @@ func TestDo_BalancedStrategy(t *testing.T) {
// Set up
dat := setup()
- db := database.DBConn
t0 := time.Date(2009, time.November, 1, 12, 0, 0, 0, time.UTC)
t1 := time.Date(2009, time.November, 8, 21, 0, 0, 0, time.UTC)
r1 := database.RepetitionRule{
@@ -423,20 +410,20 @@ func TestDo_BalancedStrategy(t *testing.T) {
UpdatedAt: t0,
},
}
- testutils.MustExec(t, db.Save(&r1), "preparing rule1")
+ testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
// Execute
c := clock.NewMock()
c.SetNow(time.Date(2009, time.November, 8, 21, 0, 0, 0, time.UTC))
- Do(c)
+ Do(testutils.DB, c)
// Test
assertLastActive(t, r1.UUID, int64(1257714000000))
assertDigestCount(t, r1, 1)
var repetition database.Digest
- testutils.MustExec(t, db.Where("rule_id = ? AND user_id = ?", r1.ID, r1.UserID).Preload("Notes").First(&repetition), "finding repetition")
+ testutils.MustExec(t, testutils.DB.Where("rule_id = ? AND user_id = ?", r1.ID, r1.UserID).Preload("Notes").First(&repetition), "finding repetition")
sort.SliceStable(repetition.Notes, func(i, j int) bool {
n1 := repetition.Notes[i]
@@ -446,8 +433,8 @@ func TestDo_BalancedStrategy(t *testing.T) {
})
var n1Record, n2Record database.Note
- testutils.MustExec(t, db.Where("uuid = ?", dat.Note1.UUID).First(&n1Record), "finding n1")
- testutils.MustExec(t, db.Where("uuid = ?", dat.Note2.UUID).First(&n2Record), "finding n2")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note1.UUID).First(&n1Record), "finding n1")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note2.UUID).First(&n2Record), "finding n2")
expected := []database.Note{n1Record, n2Record}
assert.DeepEqual(t, repetition.Notes, expected, "result mismatch")
})
diff --git a/pkg/server/job/repetition/strategy.go b/pkg/server/job/repetition/strategy.go
index 7d340361..2f03dc41 100644
--- a/pkg/server/job/repetition/strategy.go
+++ b/pkg/server/job/repetition/strategy.go
@@ -27,8 +27,7 @@ import (
"github.com/pkg/errors"
)
-func getRuleBookIDs(ruleID int) ([]int, error) {
- db := database.DBConn
+func getRuleBookIDs(db *gorm.DB, ruleID int) ([]int, error) {
var ret []int
if err := db.Table("repetition_rule_books").Select("book_id").Where("repetition_rule_id = ?", ruleID).Pluck("book_id", &ret).Error; err != nil {
return nil, errors.Wrap(err, "querying book_ids")
@@ -37,11 +36,11 @@ func getRuleBookIDs(ruleID int) ([]int, error) {
return ret, nil
}
-func applyBookDomain(noteQuery *gorm.DB, rule database.RepetitionRule) (*gorm.DB, error) {
+func applyBookDomain(db *gorm.DB, noteQuery *gorm.DB, rule database.RepetitionRule) (*gorm.DB, error) {
ret := noteQuery
if rule.BookDomain != database.BookDomainAll {
- bookIDs, err := getRuleBookIDs(rule.ID)
+ bookIDs, err := getRuleBookIDs(db, rule.ID)
if err != nil {
return nil, errors.Wrap(err, "getting book_ids")
}
@@ -58,8 +57,8 @@ func applyBookDomain(noteQuery *gorm.DB, rule database.RepetitionRule) (*gorm.DB
return ret, nil
}
-func getNotes(conn *gorm.DB, rule database.RepetitionRule, dst *[]database.Note) error {
- c, err := applyBookDomain(conn, rule)
+func getNotes(db, conn *gorm.DB, rule database.RepetitionRule, dst *[]database.Note) error {
+ c, err := applyBookDomain(db, conn, rule)
if err != nil {
return errors.Wrap(err, "building query for book threahold 1")
}
@@ -79,16 +78,14 @@ func getBalancedNotes(db *gorm.DB, rule database.RepetitionRule) ([]database.Not
t2 := now.AddDate(0, 0, -7).UnixNano()
// Get notes into three buckets with different threshold values
- var stage1 []database.Note
- var stage2 []database.Note
- var stage3 []database.Note
- if err := getNotes(db.Where("notes.added_on > ?", t1), rule, &stage1); err != nil {
+ var stage1, stage2, stage3 []database.Note
+ if err := getNotes(db, db.Where("notes.added_on > ?", t1), rule, &stage1); err != nil {
return nil, errors.Wrap(err, "Failed to get notes with threshold 1")
}
- if err := getNotes(db.Where("notes.added_on > ? AND notes.added_on < ?", t2, t1), rule, &stage2); err != nil {
+ if err := getNotes(db, db.Where("notes.added_on > ? AND notes.added_on < ?", t2, t1), rule, &stage2); err != nil {
return nil, errors.Wrap(err, "Failed to get notes with threshold 2")
}
- if err := getNotes(db.Where("notes.added_on < ?", t2), rule, &stage3); err != nil {
+ if err := getNotes(db, db.Where("notes.added_on < ?", t2), rule, &stage3); err != nil {
return nil, errors.Wrap(err, "Failed to get notes with threshold 3")
}
diff --git a/pkg/server/job/repetition/strategy_test.go b/pkg/server/job/repetition/strategy_test.go
index 07ed7d92..7e8b2fd1 100644
--- a/pkg/server/job/repetition/strategy_test.go
+++ b/pkg/server/job/repetition/strategy_test.go
@@ -34,45 +34,43 @@ func init() {
func TestApplyBookDomain(t *testing.T) {
defer testutils.ClearData()
- db := database.DBConn
-
user := testutils.SetupUserData()
b1 := database.Book{
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, db.Save(&b1), "preparing b1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
b2 := database.Book{
UserID: user.ID,
Label: "css",
}
- testutils.MustExec(t, db.Save(&b2), "preparing b2")
+ testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
b3 := database.Book{
UserID: user.ID,
Label: "golang",
}
- testutils.MustExec(t, db.Save(&b3), "preparing b3")
+ testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
n1 := database.Note{
UserID: user.ID,
BookUUID: b1.UUID,
}
- testutils.MustExec(t, db.Save(&n1), "preparing n1")
+ testutils.MustExec(t, testutils.DB.Save(&n1), "preparing n1")
n2 := database.Note{
UserID: user.ID,
BookUUID: b2.UUID,
}
- testutils.MustExec(t, db.Save(&n2), "preparing n2")
+ testutils.MustExec(t, testutils.DB.Save(&n2), "preparing n2")
n3 := database.Note{
UserID: user.ID,
BookUUID: b3.UUID,
}
- testutils.MustExec(t, db.Save(&n3), "preparing n3")
+ testutils.MustExec(t, testutils.DB.Save(&n3), "preparing n3")
var n1Record, n2Record, n3Record database.Note
- testutils.MustExec(t, db.Where("uuid = ?", n1.UUID).First(&n1Record), "finding n1")
- testutils.MustExec(t, db.Where("uuid = ?", n2.UUID).First(&n2Record), "finding n2")
- testutils.MustExec(t, db.Where("uuid = ?", n3.UUID).First(&n3Record), "finding n3")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", n1.UUID).First(&n1Record), "finding n1")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", n2.UUID).First(&n2Record), "finding n2")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", n3.UUID).First(&n3Record), "finding n3")
t.Run("book domain all", func(t *testing.T) {
rule := database.RepetitionRule{
@@ -80,7 +78,7 @@ func TestApplyBookDomain(t *testing.T) {
BookDomain: database.BookDomainAll,
}
- conn, err := applyBookDomain(db, rule)
+ conn, err := applyBookDomain(testutils.DB, testutils.DB, rule)
if err != nil {
t.Fatal(errors.Wrap(err, "executing").Error())
}
@@ -98,9 +96,9 @@ func TestApplyBookDomain(t *testing.T) {
BookDomain: database.BookDomainExluding,
Books: []database.Book{b1},
}
- testutils.MustExec(t, db.Save(&rule), "preparing rule")
+ testutils.MustExec(t, testutils.DB.Save(&rule), "preparing rule")
- conn, err := applyBookDomain(db.Debug(), rule)
+ conn, err := applyBookDomain(testutils.DB, testutils.DB, rule)
if err != nil {
t.Fatal(errors.Wrap(err, "executing").Error())
}
diff --git a/pkg/server/mailer/templates/main.go b/pkg/server/mailer/templates/main.go
index 867d0d8f..11a4b15a 100644
--- a/pkg/server/mailer/templates/main.go
+++ b/pkg/server/mailer/templates/main.go
@@ -25,15 +25,17 @@ import (
"time"
"github.com/dnote/dnote/pkg/server/database"
+ "github.com/dnote/dnote/pkg/server/dbconn"
"github.com/dnote/dnote/pkg/server/job/repetition"
"github.com/dnote/dnote/pkg/server/mailer"
+ "github.com/jinzhu/gorm"
"github.com/joho/godotenv"
_ "github.com/lib/pq"
"github.com/pkg/errors"
)
-func digestHandler(w http.ResponseWriter, r *http.Request) {
- db := database.DBConn
+func (c Context) digestHandler(w http.ResponseWriter, r *http.Request) {
+ db := c.DB
q := r.URL.Query()
digestUUID := q.Get("digest_uuid")
@@ -61,7 +63,13 @@ func digestHandler(w http.ResponseWriter, r *http.Request) {
}
now := time.Now()
- email, err := repetition.BuildEmail(now, user, "sung@getdnote.com", digest, rule)
+ email, err := repetition.BuildEmail(db, repetition.BuildEmailParams{
+ Now: now,
+ User: user,
+ EmailAddr: "sung@getdnote.com",
+ Digest: digest,
+ Rule: rule,
+ })
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@@ -71,7 +79,7 @@ func digestHandler(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(body))
}
-func emailVerificationHandler(w http.ResponseWriter, r *http.Request) {
+func (c Context) emailVerificationHandler(w http.ResponseWriter, r *http.Request) {
data := struct {
Subject string
Token string
@@ -90,7 +98,7 @@ func emailVerificationHandler(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(body))
}
-func homeHandler(w http.ResponseWriter, r *http.Request) {
+func (c Context) homeHandler(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Email development server is running."))
}
@@ -101,23 +109,29 @@ func init() {
}
}
+// Context is a context holding global information
+type Context struct {
+ DB *gorm.DB
+}
+
func main() {
- c := database.Config{
+ db := dbconn.Open(dbconn.Config{
Host: os.Getenv("DBHost"),
Port: os.Getenv("DBPort"),
Name: os.Getenv("DBName"),
User: os.Getenv("DBUser"),
Password: os.Getenv("DBPassword"),
- }
- database.Open(c)
- defer database.Close()
+ })
+ defer db.Close()
mailer.InitTemplates(nil)
log.Println("Email template development server running on http://127.0.0.1:2300")
- http.HandleFunc("/", homeHandler)
- http.HandleFunc("/digest", digestHandler)
- http.HandleFunc("/email-verification", emailVerificationHandler)
+ ctx := Context{DB: db}
+
+ http.HandleFunc("/", ctx.homeHandler)
+ http.HandleFunc("/digest", ctx.digestHandler)
+ http.HandleFunc("/email-verification", ctx.emailVerificationHandler)
log.Fatal(http.ListenAndServe(":2300", nil))
}
diff --git a/pkg/server/mailer/tokens.go b/pkg/server/mailer/tokens.go
index 6a01ae68..40a06f86 100644
--- a/pkg/server/mailer/tokens.go
+++ b/pkg/server/mailer/tokens.go
@@ -23,6 +23,7 @@ import (
"encoding/base64"
"github.com/dnote/dnote/pkg/server/database"
+ "github.com/jinzhu/gorm"
"github.com/pkg/errors"
)
@@ -39,9 +40,7 @@ func generateRandomToken(bits int) (string, error) {
// GetToken returns an token of the given kind for the user
// by first looking up any unused record and creating one if none exists.
-func GetToken(user database.User, kind string) (database.Token, error) {
- db := database.DBConn
-
+func GetToken(db *gorm.DB, user database.User, kind string) (database.Token, error) {
var tok database.Token
conn := db.
Where("user_id = ? AND type =? AND used_at IS NULL", user.ID, kind).
diff --git a/pkg/server/main.go b/pkg/server/main.go
index 9615a653..d4856a31 100644
--- a/pkg/server/main.go
+++ b/pkg/server/main.go
@@ -27,10 +27,12 @@ import (
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/database"
+ "github.com/dnote/dnote/pkg/server/dbconn"
"github.com/dnote/dnote/pkg/server/handlers"
"github.com/dnote/dnote/pkg/server/job"
"github.com/dnote/dnote/pkg/server/mailer"
"github.com/dnote/dnote/pkg/server/web"
+ "github.com/jinzhu/gorm"
"github.com/gobuffalo/packr/v2"
"github.com/pkg/errors"
@@ -64,49 +66,70 @@ func initContext() web.Context {
}
}
-func initServer() (*http.ServeMux, error) {
- apiRouter, err := handlers.NewRouter(&handlers.App{
- Clock: clock.New(),
- StripeAPIBackend: nil,
- WebURL: os.Getenv("WebURL"),
- })
+func initServer(app handlers.App) (*http.ServeMux, error) {
+ apiRouter, err := handlers.NewRouter(&app)
if err != nil {
return nil, errors.Wrap(err, "initializing router")
}
- ctx := initContext()
+ webCtx := initContext()
+ webHandlers := web.Init(webCtx)
mux := http.NewServeMux()
mux.Handle("/api/", http.StripPrefix("/api", apiRouter))
- mux.Handle("/static/", web.GetStaticHandler(ctx.StaticFileSystem))
- mux.HandleFunc("/service-worker.js", web.GetSWHandler(ctx.ServiceWorkerJs))
- mux.HandleFunc("/robots.txt", web.GetRobotsHandler(ctx.RobotsTxt))
- mux.HandleFunc("/", web.GetRootHandler(ctx.IndexHTML))
+ mux.Handle("/static/", webHandlers.GetStatic)
+ mux.HandleFunc("/service-worker.js", webHandlers.GetServiceWorker)
+ mux.HandleFunc("/robots.txt", webHandlers.GetRobots)
+ mux.HandleFunc("/", webHandlers.GetRoot)
return mux, nil
}
-func startCmd() {
- mailer.InitTemplates(nil)
+func initDB() *gorm.DB {
+ var skipSSL bool
+ if os.Getenv("GO_ENV") != "PRODUCTION" || os.Getenv("DB_NOSSL") != "" {
+ skipSSL = true
+ } else {
+ skipSSL = false
+ }
- database.Open(database.Config{
+ db := dbconn.Open(dbconn.Config{
+ SkipSSL: skipSSL,
Host: os.Getenv("DBHost"),
Port: os.Getenv("DBPort"),
Name: os.Getenv("DBName"),
User: os.Getenv("DBUser"),
Password: os.Getenv("DBPassword"),
})
- database.InitSchema()
- defer database.Close()
+ database.InitSchema(db)
- if err := database.Migrate(); err != nil {
+ return db
+}
+
+func initApp(db *gorm.DB) handlers.App {
+ return handlers.App{
+ DB: db,
+ Clock: clock.New(),
+ StripeAPIBackend: nil,
+ WebURL: os.Getenv("WebURL"),
+ }
+}
+
+func startCmd() {
+ db := initDB()
+ defer db.Close()
+
+ app := initApp(db)
+ mailer.InitTemplates(nil)
+
+ if err := database.Migrate(app.DB); err != nil {
panic(errors.Wrap(err, "running migrations"))
}
- if err := job.Run(); err != nil {
+ if err := job.Run(db); err != nil {
panic(errors.Wrap(err, "running job"))
}
- srv, err := initServer()
+ srv, err := initServer(app)
if err != nil {
panic(errors.Wrap(err, "initializing server"))
}
diff --git a/pkg/server/operations/books.go b/pkg/server/operations/books.go
index cb3bc77d..6408368b 100644
--- a/pkg/server/operations/books.go
+++ b/pkg/server/operations/books.go
@@ -20,15 +20,14 @@ package operations
import (
"github.com/dnote/dnote/pkg/clock"
- "github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/database"
+ "github.com/dnote/dnote/pkg/server/helpers"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
)
// CreateBook creates a book with the next usn and updates the user's max_usn
-func CreateBook(user database.User, clock clock.Clock, name string) (database.Book, error) {
- db := database.DBConn
+func CreateBook(db *gorm.DB, user database.User, clock clock.Clock, name string) (database.Book, error) {
tx := db.Begin()
nextUSN, err := incrementUserUSN(tx, user.ID)
diff --git a/pkg/server/operations/books_test.go b/pkg/server/operations/books_test.go
index acb0f58c..fba28276 100644
--- a/pkg/server/operations/books_test.go
+++ b/pkg/server/operations/books_test.go
@@ -29,10 +29,6 @@ import (
"github.com/pkg/errors"
)
-func init() {
- testutils.InitTestDB()
-}
-
func TestCreateBook(t *testing.T) {
testCases := []struct {
userUSN int
@@ -59,17 +55,16 @@ func TestCreateBook(t *testing.T) {
for idx, tc := range testCases {
func() {
defer testutils.ClearData()
- db := database.DBConn
user := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
c := clock.NewMock()
- book, err := CreateBook(user, c, tc.label)
+ book, err := CreateBook(testutils.DB, user, c, tc.label)
if err != nil {
t.Fatal(errors.Wrap(err, "creating book"))
}
@@ -78,13 +73,13 @@ func TestCreateBook(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
- if err := db.Model(&database.Book{}).Count(&bookCount).Error; err != nil {
+ if err := testutils.DB.Model(&database.Book{}).Count(&bookCount).Error; err != nil {
t.Fatal(errors.Wrap(err, "counting books"))
}
- if err := db.First(&bookRecord).Error; err != nil {
+ if err := testutils.DB.First(&bookRecord).Error; err != nil {
t.Fatal(errors.Wrap(err, "finding book"))
}
- if err := db.Where("id = ?", user.ID).First(&userRecord).Error; err != nil {
+ if err := testutils.DB.Where("id = ?", user.ID).First(&userRecord).Error; err != nil {
t.Fatal(errors.Wrap(err, "finding user"))
}
@@ -124,18 +119,17 @@ func TestDeleteBook(t *testing.T) {
for idx, tc := range testCases {
func() {
defer testutils.ClearData()
- db := database.DBConn
user := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
book := database.Book{UserID: user.ID, Label: "js", Deleted: false}
- testutils.MustExec(t, db.Save(&book), fmt.Sprintf("preparing book for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Save(&book), fmt.Sprintf("preparing book for test case %d", idx))
- tx := db.Begin()
+ tx := testutils.DB.Begin()
ret, err := DeleteBook(tx, user, book)
if err != nil {
tx.Rollback()
@@ -147,9 +141,9 @@ func TestDeleteBook(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
- testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
- testutils.MustExec(t, db.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
- testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, bookCount, 1, "book count mismatch")
assert.Equal(t, bookRecord.UserID, user.ID, "book user_id mismatch")
@@ -202,20 +196,19 @@ func TestUpdateBook(t *testing.T) {
for idx, tc := range testCases {
func() {
defer testutils.ClearData()
- db := database.DBConn
user := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
c := clock.NewMock()
b := database.Book{UserID: user.ID, Deleted: false, Label: tc.expectedLabel}
- testutils.MustExec(t, db.Save(&b), fmt.Sprintf("preparing book for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Save(&b), fmt.Sprintf("preparing book for test case %d", idx))
- tx := db.Begin()
+ tx := testutils.DB.Begin()
book, err := UpdateBook(tx, c, user, b, tc.payloadLabel)
if err != nil {
@@ -228,9 +221,9 @@ func TestUpdateBook(t *testing.T) {
var bookCount int
var bookRecord database.Book
var userRecord database.User
- testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
- testutils.MustExec(t, db.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
- testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, bookCount, 1, "book count mismatch")
diff --git a/pkg/server/operations/helpers_test.go b/pkg/server/operations/helpers_test.go
index cb3bee1e..68e137e6 100644
--- a/pkg/server/operations/helpers_test.go
+++ b/pkg/server/operations/helpers_test.go
@@ -28,10 +28,6 @@ import (
"github.com/pkg/errors"
)
-func init() {
- testutils.InitTestDB()
-}
-
func TestIncremenetUserUSN(t *testing.T) {
testCases := []struct {
maxUSN int
@@ -51,13 +47,12 @@ func TestIncremenetUserUSN(t *testing.T) {
for idx, tc := range testCases {
func() {
defer testutils.ClearData()
- db := database.DBConn
user := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.maxUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.maxUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
// execute
- tx := db.Begin()
+ tx := testutils.DB.Begin()
nextUSN, err := incrementUserUSN(tx, user.ID)
if err != nil {
t.Fatal(errors.Wrap(err, "incrementing the user usn"))
@@ -66,7 +61,7 @@ func TestIncremenetUserUSN(t *testing.T) {
// test
var userRecord database.User
- testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, userRecord.MaxUSN, tc.expectedMaxUSN, fmt.Sprintf("user max_usn mismatch for case %d", idx))
assert.Equal(t, nextUSN, tc.expectedMaxUSN, fmt.Sprintf("next_usn mismatch for case %d", idx))
diff --git a/pkg/server/operations/main_test.go b/pkg/server/operations/main_test.go
new file mode 100644
index 00000000..b2e7fab8
--- /dev/null
+++ b/pkg/server/operations/main_test.go
@@ -0,0 +1,17 @@
+package operations
+
+import (
+ "os"
+ "testing"
+
+ "github.com/dnote/dnote/pkg/server/testutils"
+)
+
+func TestMain(m *testing.M) {
+ testutils.InitTestDB()
+
+ code := m.Run()
+ testutils.ClearData()
+
+ os.Exit(code)
+}
diff --git a/pkg/server/operations/notes.go b/pkg/server/operations/notes.go
index 0395c6d3..ba1a9a39 100644
--- a/pkg/server/operations/notes.go
+++ b/pkg/server/operations/notes.go
@@ -29,8 +29,7 @@ import (
// CreateNote creates a note with the next usn and updates the user's max_usn.
// It returns the created note.
-func CreateNote(user database.User, clock clock.Clock, bookUUID, content string, addedOn *int64, editedOn *int64, public bool) (database.Note, error) {
- db := database.DBConn
+func CreateNote(db *gorm.DB, user database.User, clock clock.Clock, bookUUID, content string, addedOn *int64, editedOn *int64, public bool) (database.Note, error) {
tx := db.Begin()
nextUSN, err := incrementUserUSN(tx, user.ID)
@@ -163,14 +162,12 @@ func DeleteNote(tx *gorm.DB, user database.User, note database.Note) (database.N
}
// GetNote retrieves a note for the given user
-func GetNote(uuid string, user database.User) (database.Note, bool, error) {
+func GetNote(db *gorm.DB, uuid string, user database.User) (database.Note, bool, error) {
zeroNote := database.Note{}
if !helpers.ValidateUUID(uuid) {
return zeroNote, false, nil
}
- db := database.DBConn
-
conn := db.Where("notes.uuid = ? AND deleted = ?", uuid, false)
conn = database.PreloadNote(conn)
diff --git a/pkg/server/operations/notes_test.go b/pkg/server/operations/notes_test.go
index 17da24c6..f30aba96 100644
--- a/pkg/server/operations/notes_test.go
+++ b/pkg/server/operations/notes_test.go
@@ -30,10 +30,6 @@ import (
"github.com/pkg/errors"
)
-func init() {
- testutils.InitTestDB()
-}
-
func TestCreateNote(t *testing.T) {
serverTime := time.Date(2017, time.March, 14, 21, 15, 0, 0, time.UTC)
mockClock := clock.NewMock()
@@ -79,19 +75,18 @@ func TestCreateNote(t *testing.T) {
for idx, tc := range testCases {
func() {
defer testutils.ClearData()
- db := database.DBConn
user := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
b1 := database.Book{UserID: user.ID, Label: "js", Deleted: false}
- testutils.MustExec(t, db.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
- tx := db.Begin()
- if _, err := CreateNote(user, mockClock, b1.UUID, "note content", tc.addedOn, tc.editedOn, false); err != nil {
+ tx := testutils.DB.Begin()
+ if _, err := CreateNote(testutils.DB, user, mockClock, b1.UUID, "note content", tc.addedOn, tc.editedOn, false); err != nil {
tx.Rollback()
t.Fatal(errors.Wrap(err, "deleting note"))
}
@@ -101,10 +96,10 @@ func TestCreateNote(t *testing.T) {
var noteRecord database.Note
var userRecord database.User
- testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting book for test case %d", idx))
- testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), fmt.Sprintf("counting notes for test case %d", idx))
- testutils.MustExec(t, db.First(¬eRecord), fmt.Sprintf("finding note for test case %d", idx))
- testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting book for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), fmt.Sprintf("counting notes for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.First(¬eRecord), fmt.Sprintf("finding note for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, bookCount, 1, "book count mismatch")
assert.Equal(t, noteCount, 1, "note count mismatch")
@@ -139,25 +134,24 @@ func TestUpdateNote(t *testing.T) {
for idx, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
defer testutils.ClearData()
- db := database.DBConn
user := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), "preparing user max_usn for test case")
+ testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), "preparing user max_usn for test case")
anotherUser := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), "preparing user max_usn for test case")
+ testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), "preparing user max_usn for test case")
b1 := database.Book{UserID: user.ID, Label: "js", Deleted: false}
- testutils.MustExec(t, db.Save(&b1), "preparing b1 for test case")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1 for test case")
note := database.Note{UserID: user.ID, Deleted: false, Body: "test content", BookUUID: b1.UUID}
- testutils.MustExec(t, db.Save(¬e), "preparing note for test case")
+ testutils.MustExec(t, testutils.DB.Save(¬e), "preparing note for test case")
c := clock.NewMock()
content := "updated test content"
public := true
- tx := db.Begin()
+ tx := testutils.DB.Begin()
if _, err := UpdateNote(tx, user, c, note, &UpdateNoteParams{
Content: &content,
Public: &public,
@@ -171,10 +165,10 @@ func TestUpdateNote(t *testing.T) {
var noteRecord database.Note
var userRecord database.User
- testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting book for test case")
- testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes for test case")
- testutils.MustExec(t, db.First(¬eRecord), "finding note for test case")
- testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user for test case")
+ testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting book for test case")
+ testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes for test case")
+ testutils.MustExec(t, testutils.DB.First(¬eRecord), "finding note for test case")
+ testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user for test case")
expectedUSN := tc.userUSN + 1
assert.Equal(t, bookCount, 1, "book count mismatch")
@@ -211,21 +205,20 @@ func TestDeleteNote(t *testing.T) {
for idx, tc := range testCases {
func() {
defer testutils.ClearData()
- db := database.DBConn
user := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData()
- testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
b1 := database.Book{UserID: user.ID, Label: "testBook"}
- testutils.MustExec(t, db.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
note := database.Note{UserID: user.ID, Deleted: false, Body: "test content", BookUUID: b1.UUID}
- testutils.MustExec(t, db.Save(¬e), fmt.Sprintf("preparing note for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Save(¬e), fmt.Sprintf("preparing note for test case %d", idx))
- tx := db.Begin()
+ tx := testutils.DB.Begin()
ret, err := DeleteNote(tx, user, note)
if err != nil {
tx.Rollback()
@@ -237,9 +230,9 @@ func TestDeleteNote(t *testing.T) {
var noteRecord database.Note
var userRecord database.User
- testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), fmt.Sprintf("counting notes for test case %d", idx))
- testutils.MustExec(t, db.First(¬eRecord), fmt.Sprintf("finding note for test case %d", idx))
- testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), fmt.Sprintf("counting notes for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.First(¬eRecord), fmt.Sprintf("finding note for test case %d", idx))
+ testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, noteCount, 1, "note count mismatch")
@@ -261,14 +254,13 @@ func TestGetNote(t *testing.T) {
user := testutils.SetupUserData()
anotherUser := testutils.SetupUserData()
- db := database.DBConn
defer testutils.ClearData()
b1 := database.Book{
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, db.Save(&b1), "preparing b1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
privateNote := database.Note{
UserID: user.ID,
@@ -277,7 +269,7 @@ func TestGetNote(t *testing.T) {
Deleted: false,
Public: false,
}
- testutils.MustExec(t, db.Save(&privateNote), "preparing privateNote")
+ testutils.MustExec(t, testutils.DB.Save(&privateNote), "preparing privateNote")
publicNote := database.Note{
UserID: user.ID,
@@ -286,11 +278,11 @@ func TestGetNote(t *testing.T) {
Deleted: false,
Public: true,
}
- testutils.MustExec(t, db.Save(&publicNote), "preparing privateNote")
+ testutils.MustExec(t, testutils.DB.Save(&publicNote), "preparing privateNote")
var privateNoteRecord, publicNoteRecord database.Note
- testutils.MustExec(t, db.Where("uuid = ?", privateNote.UUID).Preload("Book").Preload("User").First(&privateNoteRecord), "finding privateNote")
- testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).Preload("Book").Preload("User").First(&publicNoteRecord), "finding publicNote")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", privateNote.UUID).Preload("Book").Preload("User").First(&privateNoteRecord), "finding privateNote")
+ testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).Preload("Book").Preload("User").First(&publicNoteRecord), "finding publicNote")
testCases := []struct {
name string
@@ -338,7 +330,7 @@ func TestGetNote(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- note, ok, err := GetNote(tc.note.UUID, tc.user)
+ note, ok, err := GetNote(testutils.DB, tc.note.UUID, tc.user)
if err != nil {
t.Fatal(errors.Wrap(err, "executing"))
}
@@ -352,14 +344,13 @@ func TestGetNote(t *testing.T) {
func TestGetNote_nonexistent(t *testing.T) {
user := testutils.SetupUserData()
- db := database.DBConn
defer testutils.ClearData()
b1 := database.Book{
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, db.Save(&b1), "preparing b1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
n1UUID := "4fd19336-671e-4ff3-8f22-662b80e22edc"
n1 := database.Note{
@@ -370,10 +361,10 @@ func TestGetNote_nonexistent(t *testing.T) {
Deleted: false,
Public: false,
}
- testutils.MustExec(t, db.Save(&n1), "preparing n1")
+ testutils.MustExec(t, testutils.DB.Save(&n1), "preparing n1")
nonexistentUUID := "4fd19336-671e-4ff3-8f22-662b80e22edd"
- note, ok, err := GetNote(nonexistentUUID, user)
+ note, ok, err := GetNote(testutils.DB, nonexistentUUID, user)
if err != nil {
t.Fatal(errors.Wrap(err, "executing"))
}
diff --git a/pkg/server/operations/subscriptions.go b/pkg/server/operations/subscriptions.go
index e24da365..6d51db2e 100644
--- a/pkg/server/operations/subscriptions.go
+++ b/pkg/server/operations/subscriptions.go
@@ -22,6 +22,7 @@ import (
"github.com/dnote/dnote/pkg/server/database"
"github.com/pkg/errors"
+ "github.com/jinzhu/gorm"
"github.com/stripe/stripe-go"
"github.com/stripe/stripe-go/sub"
)
@@ -66,9 +67,7 @@ func ReactivateSub(subscriptionID string, user database.User) error {
}
// MarkUnsubscribed marks the user unsubscribed
-func MarkUnsubscribed(stripeCustomerID string) error {
- db := database.DBConn
-
+func MarkUnsubscribed(db *gorm.DB, stripeCustomerID string) error {
var user database.User
if err := db.Where("stripe_customer_id = ?", stripeCustomerID).First(&user).Error; err != nil {
return errors.Wrap(err, "finding user")
diff --git a/pkg/server/operations/users.go b/pkg/server/operations/users.go
index fd88dd18..c770b3ce 100644
--- a/pkg/server/operations/users.go
+++ b/pkg/server/operations/users.go
@@ -104,8 +104,7 @@ func createDefaultRepetitionRule(user database.User, tx *gorm.DB) error {
}
// CreateUser creates a user
-func CreateUser(email, password string) (database.User, error) {
- db := database.DBConn
+func CreateUser(db *gorm.DB, email, password string) (database.User, error) {
tx := db.Begin()
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
diff --git a/pkg/server/permissions/permissions_test.go b/pkg/server/permissions/permissions_test.go
index a273d4df..c993d2fd 100644
--- a/pkg/server/permissions/permissions_test.go
+++ b/pkg/server/permissions/permissions_test.go
@@ -19,6 +19,7 @@
package permissions
import (
+ "os"
"testing"
"github.com/dnote/dnote/pkg/assert"
@@ -26,22 +27,26 @@ import (
"github.com/dnote/dnote/pkg/server/testutils"
)
-func init() {
+func TestMain(m *testing.M) {
testutils.InitTestDB()
+
+ code := m.Run()
+ testutils.ClearData()
+
+ os.Exit(code)
}
func TestViewNote(t *testing.T) {
user := testutils.SetupUserData()
anotherUser := testutils.SetupUserData()
- db := database.DBConn
defer testutils.ClearData()
b1 := database.Book{
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, db.Save(&b1), "preparing b1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
privateNote := database.Note{
UserID: user.ID,
@@ -50,7 +55,7 @@ func TestViewNote(t *testing.T) {
Deleted: false,
Public: false,
}
- testutils.MustExec(t, db.Save(&privateNote), "preparing privateNote")
+ testutils.MustExec(t, testutils.DB.Save(&privateNote), "preparing privateNote")
publicNote := database.Note{
UserID: user.ID,
@@ -59,7 +64,7 @@ func TestViewNote(t *testing.T) {
Deleted: false,
Public: true,
}
- testutils.MustExec(t, db.Save(&publicNote), "preparing privateNote")
+ testutils.MustExec(t, testutils.DB.Save(&publicNote), "preparing privateNote")
t.Run("owner accessing private note", func(t *testing.T) {
result := ViewNote(&user, privateNote)
diff --git a/pkg/server/testutils/main.go b/pkg/server/testutils/main.go
index 2621ba2d..1dcd0377 100644
--- a/pkg/server/testutils/main.go
+++ b/pkg/server/testutils/main.go
@@ -32,6 +32,7 @@ import (
"time"
"github.com/dnote/dnote/pkg/server/database"
+ "github.com/dnote/dnote/pkg/server/dbconn"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
"github.com/stripe/stripe-go"
@@ -42,31 +43,67 @@ func init() {
rand.Seed(time.Now().UnixNano())
}
+// DB is the database connection to a test database
+var DB *gorm.DB
+
// InitTestDB establishes connection pool with the test database specified by
// the environment variable configuration and initalizes a new schema
func InitTestDB() {
- c := database.Config{
+ db := dbconn.Open(dbconn.Config{
Host: os.Getenv("DBHost"),
Port: os.Getenv("DBPort"),
Name: os.Getenv("DBName"),
User: os.Getenv("DBUser"),
Password: os.Getenv("DBPassword"),
+ })
+ database.InitSchema(db)
+
+ DB = db
+}
+
+// ClearData deletes all records from the database
+func ClearData() {
+ if err := DB.Delete(&database.Book{}).Error; err != nil {
+ panic(errors.Wrap(err, "Failed to clear books"))
+ }
+ if err := DB.Delete(&database.Note{}).Error; err != nil {
+ panic(errors.Wrap(err, "Failed to clear notes"))
+ }
+ if err := DB.Delete(&database.Notification{}).Error; err != nil {
+ panic(errors.Wrap(err, "Failed to clear notifications"))
+ }
+ if err := DB.Delete(&database.User{}).Error; err != nil {
+ panic(errors.Wrap(err, "Failed to clear users"))
+ }
+ if err := DB.Delete(&database.Account{}).Error; err != nil {
+ panic(errors.Wrap(err, "Failed to clear accounts"))
+ }
+ if err := DB.Delete(&database.Token{}).Error; err != nil {
+ panic(errors.Wrap(err, "Failed to clear reset_tokens"))
+ }
+ if err := DB.Delete(&database.EmailPreference{}).Error; err != nil {
+ panic(errors.Wrap(err, "Failed to clear reset_tokens"))
+ }
+ if err := DB.Delete(&database.Session{}).Error; err != nil {
+ panic(errors.Wrap(err, "Failed to clear sessions"))
+ }
+ if err := DB.Delete(&database.Digest{}).Error; err != nil {
+ panic(errors.Wrap(err, "Failed to clear digests"))
+ }
+ if err := DB.Delete(&database.RepetitionRule{}).Error; err != nil {
+ panic(errors.Wrap(err, "Failed to clear digests"))
}
- database.Open(c)
- database.InitSchema()
}
// SetupUserData creates and returns a new user for testing purposes
func SetupUserData() database.User {
- db := database.DBConn
-
user := database.User{
APIKey: "test-api-key",
Name: "user-name",
Cloud: true,
}
- if err := db.Save(&user).Error; err != nil {
+ if err := DB.Save(&user).Error; err != nil {
panic(errors.Wrap(err, "Failed to prepare user"))
}
@@ -75,8 +112,6 @@ func SetupUserData() database.User {
// SetupAccountData creates and returns a new account for the user
func SetupAccountData(user database.User, email, password string) database.Account {
- db := database.DBConn
-
account := database.Account{
UserID: user.ID,
}
@@ -90,7 +125,7 @@ func SetupAccountData(user database.User, email, password string) database.Accou
}
account.Password = database.ToNullString(string(hashedPassword))
- if err := db.Save(&account).Error; err != nil {
+ if err := DB.Save(&account).Error; err != nil {
panic(errors.Wrap(err, "Failed to prepare account"))
}
@@ -99,8 +134,6 @@ func SetupAccountData(user database.User, email, password string) database.Accou
// SetupClassicAccountData creates and returns a new account for the user
func SetupClassicAccountData(user database.User, email string) database.Account {
- db := database.DBConn
-
// email: alice@example.com
// password: pass1234
// masterKey: WbUvagj9O6o1Z+4+7COjo7Uqm4MD2QE9EWFXne8+U+8=
@@ -117,7 +150,7 @@ func SetupClassicAccountData(user database.User, email string) database.Account
account.Email = database.ToNullString(email)
}
- if err := db.Save(&account).Error; err != nil {
+ if err := DB.Save(&account).Error; err != nil {
panic(errors.Wrap(err, "Failed to prepare account"))
}
@@ -126,14 +159,12 @@ func SetupClassicAccountData(user database.User, email string) database.Account
// SetupSession creates and returns a new user session
func SetupSession(t *testing.T, user database.User) database.Session {
- db := database.DBConn
-
session := database.Session{
Key: "Vvgm3eBXfXGEFWERI7faiRJ3DAzJw+7DdT9J1LEyNfI=",
UserID: user.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
- if err := db.Save(&session).Error; err != nil {
+ if err := DB.Save(&session).Error; err != nil {
t.Fatal(errors.Wrap(err, "Failed to prepare user"))
}
@@ -142,56 +173,18 @@ func SetupSession(t *testing.T, user database.User) database.Session {
// SetupEmailPreferenceData creates and returns a new email frequency for a user
func SetupEmailPreferenceData(user database.User, digestWeekly bool) database.EmailPreference {
- db := database.DBConn
-
frequency := database.EmailPreference{
UserID: user.ID,
DigestWeekly: digestWeekly,
}
- if err := db.Save(&frequency).Error; err != nil {
+ if err := DB.Save(&frequency).Error; err != nil {
panic(errors.Wrap(err, "Failed to prepare email frequency"))
}
return frequency
}
-// ClearData deletes all records from the database
-func ClearData() {
- db := database.DBConn
-
- if err := db.Delete(&database.Book{}).Error; err != nil {
- panic(errors.Wrap(err, "Failed to clear books"))
- }
- if err := db.Delete(&database.Note{}).Error; err != nil {
- panic(errors.Wrap(err, "Failed to clear notes"))
- }
- if err := db.Delete(&database.Notification{}).Error; err != nil {
- panic(errors.Wrap(err, "Failed to clear notifications"))
- }
- if err := db.Delete(&database.User{}).Error; err != nil {
- panic(errors.Wrap(err, "Failed to clear users"))
- }
- if err := db.Delete(&database.Account{}).Error; err != nil {
- panic(errors.Wrap(err, "Failed to clear accounts"))
- }
- if err := db.Delete(&database.Token{}).Error; err != nil {
- panic(errors.Wrap(err, "Failed to clear reset_tokens"))
- }
- if err := db.Delete(&database.EmailPreference{}).Error; err != nil {
- panic(errors.Wrap(err, "Failed to clear reset_tokens"))
- }
- if err := db.Delete(&database.Session{}).Error; err != nil {
- panic(errors.Wrap(err, "Failed to clear sessions"))
- }
- if err := db.Delete(&database.Digest{}).Error; err != nil {
- panic(errors.Wrap(err, "Failed to clear digests"))
- }
- if err := db.Delete(&database.RepetitionRule{}).Error; err != nil {
- panic(errors.Wrap(err, "Failed to clear digests"))
- }
-}
-
// HTTPDo makes an HTTP request and returns a response
func HTTPDo(t *testing.T, req *http.Request) *http.Response {
hc := http.Client{
@@ -213,8 +206,6 @@ func HTTPDo(t *testing.T, req *http.Request) *http.Response {
// HTTPAuthDo makes an HTTP request with an appropriate authorization header for a user
func HTTPAuthDo(t *testing.T, req *http.Request, user database.User) *http.Response {
- db := database.DBConn
-
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
t.Fatal(errors.Wrap(err, "reading random bits"))
@@ -225,7 +216,7 @@ func HTTPAuthDo(t *testing.T, req *http.Request, user database.User) *http.Respo
UserID: user.ID,
ExpiresAt: time.Now().Add(time.Hour * 10 * 24),
}
- if err := db.Save(&session).Error; err != nil {
+ if err := DB.Save(&session).Error; err != nil {
t.Fatal(errors.Wrap(err, "Failed to prepare user"))
}
diff --git a/pkg/server/tmpl/app.go b/pkg/server/tmpl/app.go
index abbc53b7..770ef6c4 100644
--- a/pkg/server/tmpl/app.go
+++ b/pkg/server/tmpl/app.go
@@ -24,6 +24,7 @@ import (
"net/http"
"regexp"
+ "github.com/jinzhu/gorm"
"github.com/pkg/errors"
)
@@ -58,8 +59,8 @@ func NewAppShell(content []byte) (AppShell, error) {
}
// Execute executes the index template
-func (a AppShell) Execute(r *http.Request) ([]byte, error) {
- data, err := a.getData(r)
+func (a AppShell) Execute(r *http.Request, db *gorm.DB) ([]byte, error) {
+ data, err := a.getData(db, r)
if err != nil {
return nil, errors.Wrap(err, "getting data")
}
@@ -72,11 +73,11 @@ func (a AppShell) Execute(r *http.Request) ([]byte, error) {
return buf.Bytes(), nil
}
-func (a AppShell) getData(r *http.Request) (tmplData, error) {
+func (a AppShell) getData(db *gorm.DB, r *http.Request) (tmplData, error) {
path := r.URL.Path
if ok, params := matchPath(path, notesPathRegex); ok {
- p, err := a.newNotePage(r, params[0])
+ p, err := a.newNotePage(db, r, params[0])
if err != nil {
return tmplData{}, errors.Wrap(err, "instantiating note page")
}
diff --git a/pkg/server/tmpl/app_test.go b/pkg/server/tmpl/app_test.go
index 6431eb57..d0ea83b0 100644
--- a/pkg/server/tmpl/app_test.go
+++ b/pkg/server/tmpl/app_test.go
@@ -29,10 +29,6 @@ import (
"github.com/pkg/errors"
)
-func init() {
- testutils.InitTestDB()
-}
-
func TestAppShellExecute(t *testing.T) {
t.Run("home", func(t *testing.T) {
a, err := NewAppShell([]byte("
{{ .Title }}{{ .MetaTags }}"))
@@ -45,7 +41,7 @@ func TestAppShellExecute(t *testing.T) {
t.Fatal(errors.Wrap(err, "preparing request"))
}
- b, err := a.Execute(r)
+ b, err := a.Execute(r, testutils.DB)
if err != nil {
t.Fatal(errors.Wrap(err, "executing"))
}
@@ -55,21 +51,20 @@ func TestAppShellExecute(t *testing.T) {
t.Run("note", func(t *testing.T) {
defer testutils.ClearData()
- db := database.DBConn
user := testutils.SetupUserData()
b1 := database.Book{
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, db.Save(&b1), "preparing b1")
+ testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
n1 := database.Note{
UserID: user.ID,
BookUUID: b1.UUID,
Public: true,
Body: "n1 content",
}
- testutils.MustExec(t, db.Save(&n1), "preparing note")
+ testutils.MustExec(t, testutils.DB.Save(&n1), "preparing note")
a, err := NewAppShell([]byte("{{ .MetaTags }}"))
if err != nil {
@@ -82,7 +77,7 @@ func TestAppShellExecute(t *testing.T) {
t.Fatal(errors.Wrap(err, "preparing request"))
}
- b, err := a.Execute(r)
+ b, err := a.Execute(r, testutils.DB)
if err != nil {
t.Fatal(errors.Wrap(err, "executing"))
}
diff --git a/pkg/server/tmpl/data.go b/pkg/server/tmpl/data.go
index b0c9790d..fcfce9b0 100644
--- a/pkg/server/tmpl/data.go
+++ b/pkg/server/tmpl/data.go
@@ -30,6 +30,7 @@ import (
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/handlers"
"github.com/dnote/dnote/pkg/server/operations"
+ "github.com/jinzhu/gorm"
"github.com/pkg/errors"
)
@@ -51,13 +52,13 @@ type notePage struct {
T *template.Template
}
-func (a AppShell) newNotePage(r *http.Request, noteUUID string) (notePage, error) {
- user, _, err := handlers.AuthWithSession(r, nil)
+func (a AppShell) newNotePage(db *gorm.DB, r *http.Request, noteUUID string) (notePage, error) {
+ user, _, err := handlers.AuthWithSession(db, r, nil)
if err != nil {
return notePage{}, errors.Wrap(err, "authenticating with session")
}
- note, ok, err := operations.GetNote(noteUUID, user)
+ note, ok, err := operations.GetNote(db, noteUUID, user)
if !ok {
return notePage{}, ErrNotFound
diff --git a/pkg/server/tmpl/main_test.go b/pkg/server/tmpl/main_test.go
new file mode 100644
index 00000000..81f2b954
--- /dev/null
+++ b/pkg/server/tmpl/main_test.go
@@ -0,0 +1,17 @@
+package tmpl
+
+import (
+ "os"
+ "testing"
+
+ "github.com/dnote/dnote/pkg/server/testutils"
+)
+
+func TestMain(m *testing.M) {
+ testutils.InitTestDB()
+
+ code := m.Run()
+ testutils.ClearData()
+
+ os.Exit(code)
+}
diff --git a/pkg/server/web/handlers.go b/pkg/server/web/handlers.go
index 45814b52..a20663e8 100644
--- a/pkg/server/web/handlers.go
+++ b/pkg/server/web/handlers.go
@@ -24,20 +24,40 @@ import (
"github.com/dnote/dnote/pkg/server/handlers"
"github.com/dnote/dnote/pkg/server/tmpl"
+ "github.com/jinzhu/gorm"
"github.com/pkg/errors"
)
// Context contains contents of web assets
type Context struct {
+ DB *gorm.DB
IndexHTML []byte
RobotsTxt []byte
ServiceWorkerJs []byte
StaticFileSystem http.FileSystem
}
-// GetRootHandler returns an HTTP handler that serves the app shell
-func GetRootHandler(b []byte) http.HandlerFunc {
- appShell, err := tmpl.NewAppShell(b)
+// Handlers are a group of web handlers
+type Handlers struct {
+ GetRoot http.HandlerFunc
+ GetRobots http.HandlerFunc
+ GetServiceWorker http.HandlerFunc
+ GetStatic http.Handler
+}
+
+// Init initializes the handlers
+func Init(c Context) Handlers {
+ return Handlers{
+ GetRoot: getRootHandler(c),
+ GetRobots: getRobotsHandler(c),
+ GetServiceWorker: getSWHandler(c),
+ GetStatic: getStaticHandler(c),
+ }
+}
+
+// getRootHandler returns an HTTP handler that serves the app shell
+func getRootHandler(c Context) http.HandlerFunc {
+ appShell, err := tmpl.NewAppShell(c.IndexHTML)
if err != nil {
panic(errors.Wrap(err, "initializing app shell"))
}
@@ -46,7 +66,7 @@ func GetRootHandler(b []byte) http.HandlerFunc {
// index.html must not be cached
w.Header().Set("Cache-Control", "no-cache")
- buf, err := appShell.Execute(r)
+ buf, err := appShell.Execute(r, c.DB)
if err != nil {
if errors.Cause(err) == tmpl.ErrNotFound {
handlers.RespondNotFound(w)
@@ -60,24 +80,25 @@ func GetRootHandler(b []byte) http.HandlerFunc {
}
}
-// GetRobotsHandler returns an HTTP handler that serves robots.txt
-func GetRobotsHandler(b []byte) http.HandlerFunc {
+// getRobotsHandler returns an HTTP handler that serves robots.txt
+func getRobotsHandler(c Context) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Cache-Control", "no-cache")
- w.Write(b)
+ w.Write(c.RobotsTxt)
}
}
-// GetSWHandler returns an HTTP handler that serves service worker
-func GetSWHandler(b []byte) http.HandlerFunc {
+// getSWHandler returns an HTTP handler that serves service worker
+func getSWHandler(c Context) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Content-Type", "application/javascript")
- w.Write(b)
+ w.Write(c.ServiceWorkerJs)
}
}
-// GetStaticHandler returns an HTTP handler that serves static files from a filesystem
-func GetStaticHandler(root http.FileSystem) http.Handler {
+// getStaticHandler returns an HTTP handler that serves static files from a filesystem
+func getStaticHandler(c Context) http.Handler {
+ root := c.StaticFileSystem
return http.StripPrefix("/static/", http.FileServer(root))
}
diff --git a/scripts/server/makeDemoDigests/main.go b/scripts/server/makeDemoDigests/main.go
index b878a41a..07059661 100644
--- a/scripts/server/makeDemoDigests/main.go
+++ b/scripts/server/makeDemoDigests/main.go
@@ -19,26 +19,27 @@
package main
import (
- "github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/database"
+ "github.com/dnote/dnote/pkg/server/dbconn"
+ "github.com/dnote/dnote/pkg/server/helpers"
"os"
"time"
)
func main() {
- c := database.Config{
+ db, err := dbconn.Open(dbconn.Config{
Host: os.Getenv("DBHost"),
Port: os.Getenv("DBPort"),
Name: os.Getenv("DBName"),
User: os.Getenv("DBUser"),
Password: os.Getenv("DBPassword"),
+ })
+ if err != nil {
+ panic(err)
}
- database.Open(c)
- db := database.DBConn
tx := db.Begin()
-
- userID, err := helpers.GetDemoUserID()
+ userID, err := helpers.GetDemoUserID(db)
if err != nil {
panic(err)
}
diff --git a/scripts/server/test.sh b/scripts/server/test.sh
index 8422a517..5795c6ed 100755
--- a/scripts/server/test.sh
+++ b/scripts/server/test.sh
@@ -13,6 +13,7 @@ if [ "${WATCH-false}" == true ]; then
while inotifywait --exclude .swp -e modify -r .; do go test ./... -cover -p 1; done;
set -e
else
+ # go test ./... -cover -p 1
go test ./... -cover -p 1
fi
diff --git a/web/assets/robots.txt b/web/assets/robots.txt
index 1f53798b..c2a49f4f 100644
--- a/web/assets/robots.txt
+++ b/web/assets/robots.txt
@@ -1,2 +1,2 @@
User-agent: *
-Disallow: /
+Allow: /