Refactor to avoid global database variable (#313)

* Avoid global database

* Fix Twitter summary card

* Fix CLI test
This commit is contained in:
Sung Won Cho 2019-11-16 09:45:56 +08:00 committed by GitHub
commit bd97209af8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
56 changed files with 1056 additions and 1058 deletions

View file

@ -19,11 +19,7 @@
package database
import (
"fmt"
"os"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
// Use postgres
_ "github.com/lib/pq"
@ -34,97 +30,13 @@ var (
MigrationTableName = "migrations"
)
// Config holds the connection configuration
type Config struct {
Host string
Port string
Name string
User string
Password string
}
// ErrConfigMissingHost is an error for an incomplete configuration missing the host
var ErrConfigMissingHost = errors.New("Host is empty")
// ErrConfigMissingPort is an error for an incomplete configuration missing the port
var ErrConfigMissingPort = errors.New("Port is empty")
// ErrConfigMissingName is an error for an incomplete configuration missing the name
var ErrConfigMissingName = errors.New("Name is empty")
// ErrConfigMissingUser is an error for an incomplete configuration missing the user
var ErrConfigMissingUser = errors.New("User is empty")
func validateConfig(c Config) error {
if c.Host == "" {
return ErrConfigMissingHost
}
if c.Port == "" {
return ErrConfigMissingPort
}
if c.Name == "" {
return ErrConfigMissingName
}
if c.User == "" {
return ErrConfigMissingUser
}
return nil
}
func getPGConnectionString(c Config) (string, error) {
if err := validateConfig(c); err != nil {
return "", errors.Wrap(err, "invalid database config")
}
var sslmode string
if os.Getenv("GO_ENV") == "PRODUCTION" && os.Getenv("DB_NOSSL") == "" {
sslmode = "require"
} else {
sslmode = "disable"
}
return fmt.Sprintf(
"sslmode=%s host=%s port=%s dbname=%s user=%s password=%s",
sslmode,
c.Host,
c.Port,
c.Name,
c.User,
c.Password,
), nil
}
var (
// DBConn is the connection handle for the database
DBConn *gorm.DB
)
// Open opens the connection with the database
func Open(c Config) {
connStr, err := getPGConnectionString(c)
if err != nil {
panic(err)
}
DBConn, err = gorm.Open("postgres", connStr)
if err != nil {
panic(err)
}
}
// Close closes database connection
func Close() {
DBConn.Close()
}
// InitSchema migrates database schema to reflect the latest model definition
func InitSchema() {
if err := DBConn.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`).Error; err != nil {
func InitSchema(db *gorm.DB) {
if err := db.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`).Error; err != nil {
panic(err)
}
if err := DBConn.AutoMigrate(
if err := db.AutoMigrate(
Note{},
Book{},
User{},

View file

@ -22,20 +22,20 @@ import (
"log"
"github.com/gobuffalo/packr/v2"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
"github.com/rubenv/sql-migrate"
)
// Migrate runs the migrations
func Migrate() error {
func Migrate(db *gorm.DB) error {
migrations := &migrate.PackrMigrationSource{
Box: packr.New("migrations", "../database/migrations/"),
}
migrate.SetTable(MigrationTableName)
db := DBConn.DB()
n, err := migrate.Exec(db, "postgres", migrations, migrate.Up)
n, err := migrate.Exec(db.DB(), "postgres", migrations, migrate.Up)
if err != nil {
return errors.Wrap(err, "running migrations")
}

View file

@ -23,7 +23,7 @@ import (
"fmt"
"os"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/dbconn"
"github.com/joho/godotenv"
"github.com/pkg/errors"
"github.com/rubenv/sql-migrate"
@ -43,20 +43,18 @@ func init() {
}
}
c := database.Config{
Host: os.Getenv("DBHost"),
Port: os.Getenv("DBPort"),
Name: os.Getenv("DBName"),
User: os.Getenv("DBUser"),
Password: os.Getenv("DBPassword"),
}
database.Open(c)
}
func main() {
flag.Parse()
db := database.DBConn
db := dbconn.Open(dbconn.Config{
Host: os.Getenv("DBHost"),
Port: os.Getenv("DBPort"),
Name: os.Getenv("DBName"),
User: os.Getenv("DBUser"),
Password: os.Getenv("DBPassword"),
})
migrations := &migrate.FileMigrationSource{
Dir: *migrationDir,

View file

@ -0,0 +1,85 @@
package dbconn
import (
"fmt"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
)
// Config holds the connection configuration
type Config struct {
SkipSSL bool
Host string
Port string
Name string
User string
Password string
}
// ErrConfigMissingHost is an error for an incomplete configuration missing the host
var ErrConfigMissingHost = errors.New("Host is empty")
// ErrConfigMissingPort is an error for an incomplete configuration missing the port
var ErrConfigMissingPort = errors.New("Port is empty")
// ErrConfigMissingName is an error for an incomplete configuration missing the name
var ErrConfigMissingName = errors.New("Name is empty")
// ErrConfigMissingUser is an error for an incomplete configuration missing the user
var ErrConfigMissingUser = errors.New("User is empty")
func validateConfig(c Config) error {
if c.Host == "" {
return ErrConfigMissingHost
}
if c.Port == "" {
return ErrConfigMissingPort
}
if c.Name == "" {
return ErrConfigMissingName
}
if c.User == "" {
return ErrConfigMissingUser
}
return nil
}
func getPGConnectionString(c Config) (string, error) {
if err := validateConfig(c); err != nil {
return "", errors.Wrap(err, "invalid database config")
}
var sslmode string
if c.SkipSSL {
sslmode = "disable"
} else {
sslmode = "require"
}
return fmt.Sprintf(
"sslmode=%s host=%s port=%s dbname=%s user=%s password=%s",
sslmode,
c.Host,
c.Port,
c.Name,
c.User,
c.Password,
), nil
}
// Open opens the connection with the database
func Open(c Config) *gorm.DB {
connStr, err := getPGConnectionString(c)
if err != nil {
panic(errors.Wrap(err, "getting connection string"))
}
conn, err := gorm.Open("postgres", connStr)
if err != nil {
panic(errors.Wrap(err, "opening database connection"))
}
return conn
}

View file

@ -16,7 +16,7 @@
* along with Dnote. If not, see <https://www.gnu.org/licenses/>.
*/
package database
package dbconn
import (
"github.com/dnote/dnote/pkg/assert"
@ -38,6 +38,17 @@ func TestValidateConfig(t *testing.T) {
},
expected: nil,
},
{
input: Config{
SkipSSL: true,
Host: "mockHost",
Port: "mockPort",
Name: "mockName",
User: "mockUser",
Password: "mockPassword",
},
expected: nil,
},
{
input: Config{
Host: "mockHost",

View file

@ -24,10 +24,10 @@ import (
"net/http"
"time"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/operations"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/mailer"
"github.com/dnote/dnote/pkg/server/operations"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
)
@ -60,10 +60,8 @@ func (a *App) getMe(w http.ResponseWriter, r *http.Request) {
return
}
db := database.DBConn
var account database.Account
if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
if err := a.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
HandleError(w, "finding account", err, http.StatusInternalServerError)
return
}
@ -76,7 +74,7 @@ func (a *App) getMe(w http.ResponseWriter, r *http.Request) {
User: session,
}
tx := db.Begin()
tx := a.DB.Begin()
if err := operations.TouchLastLoginAt(user, tx); err != nil {
tx.Rollback()
// In case of an error, gracefully continue to avoid disturbing the service
@ -92,8 +90,6 @@ type createResetTokenPayload struct {
}
func (a *App) createResetToken(w http.ResponseWriter, r *http.Request) {
db := database.DBConn
var params createResetTokenPayload
if err := json.NewDecoder(r.Body).Decode(&params); err != nil {
http.Error(w, "invalid payload", http.StatusBadRequest)
@ -101,7 +97,7 @@ func (a *App) createResetToken(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
conn := db.Where("email = ?", params.Email).First(&account)
conn := a.DB.Where("email = ?", params.Email).First(&account)
if conn.RecordNotFound() {
return
}
@ -127,7 +123,7 @@ func (a *App) createResetToken(w http.ResponseWriter, r *http.Request) {
Type: database.TokenTypeResetPassword,
}
if err := db.Save(&token).Error; err != nil {
if err := a.DB.Save(&token).Error; err != nil {
HandleError(w, errors.Wrap(err, "saving token").Error(), nil, http.StatusInternalServerError)
return
}
@ -156,8 +152,6 @@ type resetPasswordPayload struct {
}
func (a *App) resetPassword(w http.ResponseWriter, r *http.Request) {
db := database.DBConn
var params resetPasswordPayload
if err := json.NewDecoder(r.Body).Decode(&params); err != nil {
http.Error(w, "invalid payload", http.StatusBadRequest)
@ -165,7 +159,7 @@ func (a *App) resetPassword(w http.ResponseWriter, r *http.Request) {
}
var token database.Token
conn := db.Where("value = ? AND type =? AND used_at IS NULL", params.Token, database.TokenTypeResetPassword).First(&token)
conn := a.DB.Where("value = ? AND type =? AND used_at IS NULL", params.Token, database.TokenTypeResetPassword).First(&token)
if conn.RecordNotFound() {
http.Error(w, "invalid token", http.StatusBadRequest)
return
@ -186,7 +180,7 @@ func (a *App) resetPassword(w http.ResponseWriter, r *http.Request) {
return
}
tx := db.Begin()
tx := a.DB.Begin()
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(params.Password), bcrypt.DefaultCost)
if err != nil {
@ -196,7 +190,7 @@ func (a *App) resetPassword(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
if err := db.Where("user_id = ?", token.UserID).First(&account).Error; err != nil {
if err := a.DB.Where("user_id = ?", token.UserID).First(&account).Error; err != nil {
tx.Rollback()
HandleError(w, errors.Wrap(err, "finding user").Error(), nil, http.StatusInternalServerError)
return
@ -216,10 +210,10 @@ func (a *App) resetPassword(w http.ResponseWriter, r *http.Request) {
tx.Commit()
var user database.User
if err := db.Where("id = ?", account.UserID).First(&user).Error; err != nil {
if err := a.DB.Where("id = ?", account.UserID).First(&user).Error; err != nil {
HandleError(w, errors.Wrap(err, "finding user").Error(), nil, http.StatusInternalServerError)
return
}
respondWithSession(w, user.ID, http.StatusOK)
respondWithSession(a.DB, w, user.ID, http.StatusOK)
}

View file

@ -31,8 +31,8 @@ import (
)
func TestGetMe(t *testing.T) {
testutils.InitTestDB()
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
@ -41,7 +41,7 @@ func TestGetMe(t *testing.T) {
defer server.Close()
u := testutils.SetupUserData()
testutils.SetupAccountData(u, "alice@example.com", "somepassword")
testutils.SetupAccountData( u, "alice@example.com", "somepassword")
dat := `{"email": "alice@example.com"}`
req := testutils.MakeReq(server, "POST", "/reset-token", dat)
@ -53,23 +53,24 @@ func TestGetMe(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "Status code mismtach")
var user database.User
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user")
testutils.MustExec(t, testutils.DB.Where("id = ?", u.ID).First(&user), "finding user")
assert.Equal(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch")
}
func TestCreateResetToken(t *testing.T) {
t.Run("success", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
u := testutils.SetupUserData()
testutils.SetupAccountData(u, "alice@example.com", "somepassword")
testutils.SetupAccountData( u, "alice@example.com", "somepassword")
dat := `{"email": "alice@example.com"}`
req := testutils.MakeReq(server, "POST", "/reset-token", dat)
@ -81,10 +82,10 @@ func TestCreateResetToken(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "Status code mismtach")
var tokenCount int
testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting tokens")
testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting tokens")
var resetToken database.Token
testutils.MustExec(t, db.Where("user_id = ? AND type = ?", u.ID, database.TokenTypeResetPassword).First(&resetToken), "finding reset token")
testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", u.ID, database.TokenTypeResetPassword).First(&resetToken), "finding reset token")
assert.Equal(t, tokenCount, 1, "reset_token count mismatch")
assert.NotEqual(t, resetToken.Value, nil, "reset_token value mismatch")
@ -92,17 +93,18 @@ func TestCreateResetToken(t *testing.T) {
})
t.Run("nonexistent email", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
u := testutils.SetupUserData()
testutils.SetupAccountData(u, "alice@example.com", "somepassword")
testutils.SetupAccountData( u, "alice@example.com", "somepassword")
dat := `{"email": "bob@example.com"}`
req := testutils.MakeReq(server, "POST", "/reset-token", dat)
@ -114,36 +116,37 @@ func TestCreateResetToken(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "Status code mismtach")
var tokenCount int
testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting tokens")
testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting tokens")
assert.Equal(t, tokenCount, 0, "reset_token count mismatch")
})
}
func TestResetPassword(t *testing.T) {
t.Run("success", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
u := testutils.SetupUserData()
a := testutils.SetupAccountData(u, "alice@example.com", "oldpassword")
a := testutils.SetupAccountData( u, "alice@example.com", "oldpassword")
tok := database.Token{
UserID: u.ID,
Value: "MivFxYiSMMA4An9dP24DNQ==",
Type: database.TokenTypeResetPassword,
}
testutils.MustExec(t, db.Save(&tok), "preparing token")
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
otherTok := database.Token{
UserID: u.ID,
Value: "somerandomvalue",
Type: database.TokenTypeEmailVerification,
}
testutils.MustExec(t, db.Save(&otherTok), "preparing another token")
testutils.MustExec(t, testutils.DB.Save(&otherTok), "preparing another token")
dat := `{"token": "MivFxYiSMMA4An9dP24DNQ==", "password": "newpassword"}`
req := testutils.MakeReq(server, "PATCH", "/reset-password", dat)
@ -156,9 +159,9 @@ func TestResetPassword(t *testing.T) {
var resetToken, verificationToken database.Token
var account database.Account
testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
testutils.MustExec(t, db.Where("value = ?", "somerandomvalue").First(&verificationToken), "finding reset token")
testutils.MustExec(t, db.Where("id = ?", a.ID).First(&account), "finding account")
testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
testutils.MustExec(t, testutils.DB.Where("value = ?", "somerandomvalue").First(&verificationToken), "finding reset token")
testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "finding account")
assert.NotEqual(t, resetToken.UsedAt, nil, "reset_token UsedAt mismatch")
passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte("newpassword"))
@ -167,23 +170,24 @@ func TestResetPassword(t *testing.T) {
})
t.Run("nonexistent token", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
u := testutils.SetupUserData()
a := testutils.SetupAccountData(u, "alice@example.com", "somepassword")
a := testutils.SetupAccountData( u, "alice@example.com", "somepassword")
tok := database.Token{
UserID: u.ID,
Value: "MivFxYiSMMA4An9dP24DNQ==",
Type: database.TokenTypeResetPassword,
}
testutils.MustExec(t, db.Save(&tok), "preparing token")
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
dat := `{"token": "-ApMnyvpg59uOU5b-Kf5uQ==", "password": "oldpassword"}`
req := testutils.MakeReq(server, "PATCH", "/reset-password", dat)
@ -196,8 +200,8 @@ func TestResetPassword(t *testing.T) {
var resetToken database.Token
var account database.Account
testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
testutils.MustExec(t, db.Where("id = ?", a.ID).First(&account), "finding account")
testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "finding account")
assert.Equal(t, a.Password, account.Password, "password should not have been updated")
assert.Equal(t, a.Password, account.Password, "password should not have been updated")
@ -205,24 +209,25 @@ func TestResetPassword(t *testing.T) {
})
t.Run("expired token", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
u := testutils.SetupUserData()
a := testutils.SetupAccountData(u, "alice@example.com", "somepassword")
a := testutils.SetupAccountData( u, "alice@example.com", "somepassword")
tok := database.Token{
UserID: u.ID,
Value: "MivFxYiSMMA4An9dP24DNQ==",
Type: database.TokenTypeResetPassword,
}
testutils.MustExec(t, db.Save(&tok), "preparing token")
testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
dat := `{"token": "MivFxYiSMMA4An9dP24DNQ==", "password": "oldpassword"}`
req := testutils.MakeReq(server, "PATCH", "/reset-password", dat)
@ -235,24 +240,25 @@ func TestResetPassword(t *testing.T) {
var resetToken database.Token
var account database.Account
testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
testutils.MustExec(t, db.Where("id = ?", a.ID).First(&account), "failed to find account")
testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "failed to find account")
assert.Equal(t, a.Password, account.Password, "password should not have been updated")
assert.Equal(t, resetToken.UsedAt, (*time.Time)(nil), "used_at should be nil")
})
t.Run("used token", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
u := testutils.SetupUserData()
a := testutils.SetupAccountData(u, "alice@example.com", "somepassword")
a := testutils.SetupAccountData( u, "alice@example.com", "somepassword")
usedAt := time.Now().Add(time.Hour * -11).UTC()
tok := database.Token{
@ -261,8 +267,8 @@ func TestResetPassword(t *testing.T) {
Type: database.TokenTypeResetPassword,
UsedAt: &usedAt,
}
testutils.MustExec(t, db.Save(&tok), "preparing token")
testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
dat := `{"token": "MivFxYiSMMA4An9dP24DNQ==", "password": "oldpassword"}`
req := testutils.MakeReq(server, "PATCH", "/reset-password", dat)
@ -275,8 +281,8 @@ func TestResetPassword(t *testing.T) {
var resetToken database.Token
var account database.Account
testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
testutils.MustExec(t, db.Where("id = ?", a.ID).First(&account), "failed to find account")
testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "failed to find account")
assert.Equal(t, a.Password, account.Password, "password should not have been updated")
if resetToken.UsedAt.Year() != usedAt.Year() ||
@ -290,24 +296,25 @@ func TestResetPassword(t *testing.T) {
})
t.Run("using wrong type token: email_verification", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
u := testutils.SetupUserData()
a := testutils.SetupAccountData(u, "alice@example.com", "somepassword")
a := testutils.SetupAccountData( u, "alice@example.com", "somepassword")
tok := database.Token{
UserID: u.ID,
Value: "MivFxYiSMMA4An9dP24DNQ==",
Type: database.TokenTypeEmailVerification,
}
testutils.MustExec(t, db.Save(&tok), "Failed to prepare reset_token")
testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
testutils.MustExec(t, testutils.DB.Save(&tok), "Failed to prepare reset_token")
testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
dat := `{"token": "MivFxYiSMMA4An9dP24DNQ==", "password": "oldpassword"}`
req := testutils.MakeReq(server, "PATCH", "/reset-password", dat)
@ -320,8 +327,8 @@ func TestResetPassword(t *testing.T) {
var resetToken database.Token
var account database.Account
testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
testutils.MustExec(t, db.Where("id = ?", a.ID).First(&account), "failed to find account")
testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "failed to find account")
assert.Equal(t, a.Password, account.Password, "password should not have been updated")
assert.Equal(t, resetToken.UsedAt, (*time.Time)(nil), "used_at should be nil")

View file

@ -23,11 +23,11 @@ import (
"net/http"
"github.com/dnote/dnote/pkg/server/crypt"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/log"
"github.com/dnote/dnote/pkg/server/operations"
"github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/log"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
)
@ -39,15 +39,13 @@ func (a *App) classicMigrate(w http.ResponseWriter, r *http.Request) {
return
}
db := database.DBConn
var account database.Account
if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
if err := a.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
HandleError(w, "finding account", err, http.StatusInternalServerError)
return
}
if err := db.Model(&account).
if err := a.DB.Model(&account).
Update(map[string]interface{}{
"salt": "",
"auth_key_hash": "",
@ -66,8 +64,6 @@ type PresigninResponse struct {
}
func (a *App) classicPresignin(w http.ResponseWriter, r *http.Request) {
db := database.DBConn
q := r.URL.Query()
email := q.Get("email")
if email == "" {
@ -76,7 +72,7 @@ func (a *App) classicPresignin(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
conn := db.Where("email = ?", email).First(&account)
conn := a.DB.Where("email = ?", email).First(&account)
if !conn.RecordNotFound() && conn.Error != nil {
HandleError(w, "getting user", conn.Error, http.StatusInternalServerError)
return
@ -106,8 +102,6 @@ type classicSigninPayload struct {
}
func (a *App) classicSignin(w http.ResponseWriter, r *http.Request) {
db := database.DBConn
var params classicSigninPayload
if err := json.NewDecoder(r.Body).Decode(&params); err != nil {
HandleError(w, "decoding payload", err, http.StatusInternalServerError)
@ -120,7 +114,7 @@ func (a *App) classicSignin(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
conn := db.Where("email = ?", params.Email).First(&account)
conn := a.DB.Where("email = ?", params.Email).First(&account)
if conn.RecordNotFound() {
http.Error(w, ErrLoginFailure.Error(), http.StatusUnauthorized)
return
@ -138,7 +132,7 @@ func (a *App) classicSignin(w http.ResponseWriter, r *http.Request) {
return
}
session, err := operations.CreateSession(db, account.UserID)
session, err := operations.CreateSession(a.DB, account.UserID)
if err != nil {
HandleError(w, "creating session", nil, http.StatusBadRequest)
return
@ -169,10 +163,8 @@ func (a *App) classicGetMe(w http.ResponseWriter, r *http.Request) {
return
}
db := database.DBConn
var account database.Account
if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
if err := a.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
HandleError(w, "finding account", err, http.StatusInternalServerError)
return
}
@ -229,8 +221,6 @@ func (a *App) classicSetPassword(w http.ResponseWriter, r *http.Request) {
return
}
db := database.DBConn
var params classicSetPasswordPayload
if err := json.NewDecoder(r.Body).Decode(&params); err != nil {
HandleError(w, "decoding payload", err, http.StatusInternalServerError)
@ -238,7 +228,7 @@ func (a *App) classicSetPassword(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
if err := a.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
HandleError(w, "getting user", nil, http.StatusInternalServerError)
return
}
@ -249,7 +239,7 @@ func (a *App) classicSetPassword(w http.ResponseWriter, r *http.Request) {
return
}
if err := db.Model(&account).Update("password", string(hashedNewPassword)).Error; err != nil {
if err := a.DB.Model(&account).Update("password", string(hashedNewPassword)).Error; err != nil {
http.Error(w, errors.Wrap(err, "updating password").Error(), http.StatusInternalServerError)
return
}
@ -265,8 +255,7 @@ func (a *App) classicGetNotes(w http.ResponseWriter, r *http.Request) {
}
var notes []database.Note
db := database.DBConn
if err := db.Where("user_id = ? AND encrypted = true", user.ID).Find(&notes).Error; err != nil {
if err := a.DB.Where("user_id = ? AND encrypted = true", user.ID).Find(&notes).Error; err != nil {
HandleError(w, "finding notes", err, http.StatusInternalServerError)
return
}

View file

@ -22,26 +22,16 @@ import (
"encoding/json"
"fmt"
"net/http"
"os"
"testing"
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/mailer"
"github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
)
func init() {
testutils.InitTestDB()
templatePath := os.Getenv("DNOTE_TEST_EMAIL_TEMPLATE_DIR")
mailer.InitTemplates(&templatePath)
}
func TestClassicPresignin(t *testing.T) {
db := database.DBConn
defer testutils.ClearData()
alice := database.Account{
@ -52,8 +42,8 @@ func TestClassicPresignin(t *testing.T) {
Email: database.ToNullString("bob@example.com"),
ClientKDFIteration: 200000,
}
testutils.MustExec(t, db.Save(&alice), "saving alice")
testutils.MustExec(t, db.Save(&bob), "saving bob")
testutils.MustExec(t, testutils.DB.Save(&alice), "saving alice")
testutils.MustExec(t, testutils.DB.Save(&bob), "saving bob")
testCases := []struct {
email string
@ -121,12 +111,11 @@ func TestClassicPresignin_MissingParams(t *testing.T) {
}
func TestClassicSignin(t *testing.T) {
db := database.DBConn
defer testutils.ClearData()
user := testutils.SetupUserData()
alice := testutils.SetupClassicAccountData(user, "alice@example.com")
testutils.MustExec(t, db.Save(&alice), "saving alice")
testutils.MustExec(t, testutils.DB.Save(&alice), "saving alice")
// Setup
server := MustNewServer(t, &App{
@ -145,8 +134,8 @@ func TestClassicSignin(t *testing.T) {
var sessionCount int
var session database.Session
testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
testutils.MustExec(t, db.First(&session), "getting session")
testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
testutils.MustExec(t, testutils.DB.First(&session), "getting session")
var got SessionResponse
if err := json.NewDecoder(res.Body).Decode(&got); err != nil {
@ -165,7 +154,7 @@ func TestClassicSignin(t *testing.T) {
}
func TestClassicSignin_Failure(t *testing.T) {
db := database.DBConn
defer testutils.ClearData()
//password: correctbattery
@ -183,8 +172,8 @@ func TestClassicSignin_Failure(t *testing.T) {
// plain authKey: DN4d/teaq1I2bVYZ7QWaah4Fu7q2y2N4yJNZk76hFHw=
AuthKeyHash: "fGOMHHAw9G7CH4Gv2EM1ZcZZklC1a55fS3QJ0qQVp4k=",
}
testutils.MustExec(t, db.Save(&alice), "saving alice")
testutils.MustExec(t, db.Save(&bob), "saving bob")
testutils.MustExec(t, testutils.DB.Save(&alice), "saving alice")
testutils.MustExec(t, testutils.DB.Save(&bob), "saving bob")
testCases := []struct {
email string

View file

@ -25,13 +25,13 @@ import (
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/testutils"
"github.com/jinzhu/gorm"
)
func TestCheckHealth(t *testing.T) {
defer testutils.ClearData()
// Setup
server := MustNewServer(t, &App{
DB: &gorm.DB{},
Clock: clock.NewMock(),
})
defer server.Close()

View file

@ -0,0 +1,20 @@
package handlers
import (
"os"
"testing"
"github.com/dnote/dnote/pkg/server/mailer"
"github.com/dnote/dnote/pkg/server/testutils"
)
func TestMain(m *testing.M) {
testutils.InitTestDB()
templatePath := os.Getenv("DNOTE_TEST_EMAIL_TEMPLATE_DIR")
mailer.InitTemplates(&templatePath)
code := m.Run()
testutils.ClearData()
os.Exit(code)
}

View file

@ -86,9 +86,7 @@ func parseSearchQuery(q url.Values) string {
return escapeSearchQuery(searchStr)
}
func getNoteBaseQuery(noteUUID string, search string) *gorm.DB {
db := database.DBConn
func getNoteBaseQuery(db *gorm.DB, noteUUID string, search string) *gorm.DB {
var conn *gorm.DB
if search != "" {
conn = selectFTSFields(db, search, &ftsParams{HighlightAll: true})
@ -102,7 +100,7 @@ func getNoteBaseQuery(noteUUID string, search string) *gorm.DB {
}
func (a *App) getNote(w http.ResponseWriter, r *http.Request) {
user, _, err := AuthWithSession(r, nil)
user, _, err := AuthWithSession(a.DB, r, nil)
if err != nil {
HandleError(w, "authenticating", err, http.StatusInternalServerError)
return
@ -111,7 +109,7 @@ func (a *App) getNote(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
noteUUID := vars["noteUUID"]
note, ok, err := operations.GetNote(noteUUID, user)
note, ok, err := operations.GetNote(a.DB, noteUUID, user)
if !ok {
RespondNotFound(w)
return
@ -145,17 +143,17 @@ func (a *App) getNotes(w http.ResponseWriter, r *http.Request) {
}
query := r.URL.Query()
respondGetNotes(user.ID, query, w)
respondGetNotes(a.DB, user.ID, query, w)
}
func respondGetNotes(userID int, query url.Values, w http.ResponseWriter) {
func respondGetNotes(db *gorm.DB, userID int, query url.Values, w http.ResponseWriter) {
q, err := parseGetNotesQuery(query)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
conn := getNotesBaseQuery(userID, q)
conn := getNotesBaseQuery(db, userID, q)
var total int
if err := conn.Model(database.Note{}).Count(&total).Error; err != nil {
@ -274,9 +272,7 @@ func getDateBounds(year, month int) (int64, int64) {
return lower, upper
}
func getNotesBaseQuery(userID int, q getNotesQuery) *gorm.DB {
db := database.DBConn
func getNotesBaseQuery(db *gorm.DB, userID int, q getNotesQuery) *gorm.DB {
conn := db.Where(
"notes.user_id = ? AND notes.deleted = ? AND notes.encrypted = ?",
userID, false, q.Encrypted,
@ -317,8 +313,7 @@ func (a *App) legacyGetNotes(w http.ResponseWriter, r *http.Request) {
}
var notes []database.Note
db := database.DBConn
if err := db.Where("user_id = ? AND encrypted = true", user.ID).Find(&notes).Error; err != nil {
if err := a.DB.Where("user_id = ? AND encrypted = true", user.ID).Find(&notes).Error; err != nil {
HandleError(w, "finding notes", err, http.StatusInternalServerError)
return
}

View file

@ -28,16 +28,12 @@ import (
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
)
func init() {
testutils.InitTestDB()
}
func getExpectedNotePayload(n database.Note, b database.Book, u database.User) presenters.Note {
return presenters.Note{
UUID: n.UUID,
@ -59,11 +55,12 @@ func getExpectedNotePayload(n database.Note, b database.Book, u database.User) p
}
func TestGetNotes(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -75,17 +72,17 @@ func TestGetNotes(t *testing.T) {
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, db.Save(&b1), "preparing b1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
b2 := database.Book{
UserID: user.ID,
Label: "css",
}
testutils.MustExec(t, db.Save(&b2), "preparing b2")
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
b3 := database.Book{
UserID: anotherUser.ID,
Label: "css",
}
testutils.MustExec(t, db.Save(&b3), "preparing b3")
testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
n1 := database.Note{
UserID: user.ID,
@ -95,7 +92,7 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
testutils.MustExec(t, db.Save(&n1), "preparing n1")
testutils.MustExec(t, testutils.DB.Save(&n1), "preparing n1")
n2 := database.Note{
UserID: user.ID,
BookUUID: b1.UUID,
@ -104,7 +101,7 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2018, time.August, 11, 22, 0, 0, 0, time.UTC).UnixNano(),
}
testutils.MustExec(t, db.Save(&n2), "preparing n2")
testutils.MustExec(t, testutils.DB.Save(&n2), "preparing n2")
n3 := database.Note{
UserID: user.ID,
BookUUID: b1.UUID,
@ -113,7 +110,7 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2017, time.January, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
testutils.MustExec(t, db.Save(&n3), "preparing n3")
testutils.MustExec(t, testutils.DB.Save(&n3), "preparing n3")
n4 := database.Note{
UserID: user.ID,
BookUUID: b2.UUID,
@ -122,7 +119,7 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2018, time.September, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
testutils.MustExec(t, db.Save(&n4), "preparing n4")
testutils.MustExec(t, testutils.DB.Save(&n4), "preparing n4")
n5 := database.Note{
UserID: anotherUser.ID,
BookUUID: b3.UUID,
@ -131,7 +128,7 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
testutils.MustExec(t, db.Save(&n5), "preparing n5")
testutils.MustExec(t, testutils.DB.Save(&n5), "preparing n5")
n6 := database.Note{
UserID: user.ID,
BookUUID: b1.UUID,
@ -140,7 +137,7 @@ func TestGetNotes(t *testing.T) {
Deleted: true,
AddedOn: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
testutils.MustExec(t, db.Save(&n6), "preparing n6")
testutils.MustExec(t, testutils.DB.Save(&n6), "preparing n6")
// Execute
req := testutils.MakeReq(server, "GET", "/notes?year=2018&month=8", "")
@ -155,8 +152,8 @@ func TestGetNotes(t *testing.T) {
}
var n2Record, n1Record database.Note
testutils.MustExec(t, db.Where("uuid = ?", n2.UUID).First(&n2Record), "finding n2Record")
testutils.MustExec(t, db.Where("uuid = ?", n1.UUID).First(&n1Record), "finding n1Record")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", n2.UUID).First(&n2Record), "finding n2Record")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", n1.UUID).First(&n1Record), "finding n1Record")
expected := GetNotesResponse{
Notes: []presenters.Note{
@ -170,11 +167,12 @@ func TestGetNotes(t *testing.T) {
}
func TestGetNote(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -186,7 +184,7 @@ func TestGetNote(t *testing.T) {
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, db.Save(&b1), "preparing b1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
privateNote := database.Note{
UserID: user.ID,
@ -194,20 +192,20 @@ func TestGetNote(t *testing.T) {
Body: "privateNote content",
Public: false,
}
testutils.MustExec(t, db.Save(&privateNote), "preparing privateNote")
testutils.MustExec(t, testutils.DB.Save(&privateNote), "preparing privateNote")
publicNote := database.Note{
UserID: user.ID,
BookUUID: b1.UUID,
Body: "publicNote content",
Public: true,
}
testutils.MustExec(t, db.Save(&publicNote), "preparing publicNote")
testutils.MustExec(t, testutils.DB.Save(&publicNote), "preparing publicNote")
deletedNote := database.Note{
UserID: user.ID,
BookUUID: b1.UUID,
Deleted: true,
}
testutils.MustExec(t, db.Save(&deletedNote), "preparing publicNote")
testutils.MustExec(t, testutils.DB.Save(&deletedNote), "preparing publicNote")
t.Run("owner accessing private note", func(t *testing.T) {
// Execute
@ -224,7 +222,7 @@ func TestGetNote(t *testing.T) {
}
var n1Record database.Note
testutils.MustExec(t, db.Where("uuid = ?", privateNote.UUID).First(&n1Record), "finding n1Record")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", privateNote.UUID).First(&n1Record), "finding n1Record")
expected := getExpectedNotePayload(n1Record, b1, user)
assert.DeepEqual(t, payload, expected, "payload mismatch")
@ -245,7 +243,7 @@ func TestGetNote(t *testing.T) {
}
var n2Record database.Note
testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
expected := getExpectedNotePayload(n2Record, b1, user)
assert.DeepEqual(t, payload, expected, "payload mismatch")
@ -266,7 +264,7 @@ func TestGetNote(t *testing.T) {
}
var n2Record database.Note
testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
expected := getExpectedNotePayload(n2Record, b1, user)
assert.DeepEqual(t, payload, expected, "payload mismatch")
@ -304,7 +302,7 @@ func TestGetNote(t *testing.T) {
}
var n2Record database.Note
testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
expected := getExpectedNotePayload(n2Record, b1, user)
assert.DeepEqual(t, payload, expected, "payload mismatch")

View file

@ -23,9 +23,9 @@ import (
"net/http"
"time"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/database"
"github.com/gorilla/mux"
"github.com/pkg/errors"
)
@ -45,9 +45,8 @@ func (a *App) getRepetitionRule(w http.ResponseWriter, r *http.Request) {
return
}
db := database.DBConn
var repetitionRule database.RepetitionRule
if err := db.Where("user_id = ? AND uuid = ?", user.ID, repetitionRuleUUID).Preload("Books").Find(&repetitionRule).Error; err != nil {
if err := a.DB.Where("user_id = ? AND uuid = ?", user.ID, repetitionRuleUUID).Preload("Books").Find(&repetitionRule).Error; err != nil {
HandleError(w, "getting repetition rules", err, http.StatusInternalServerError)
return
}
@ -63,9 +62,8 @@ func (a *App) getRepetitionRules(w http.ResponseWriter, r *http.Request) {
return
}
db := database.DBConn
var repetitionRules []database.RepetitionRule
if err := db.Where("user_id = ?", user.ID).Preload("Books").Order("last_active DESC").Find(&repetitionRules).Error; err != nil {
if err := a.DB.Where("user_id = ?", user.ID).Preload("Books").Order("last_active DESC").Find(&repetitionRules).Error; err != nil {
HandleError(w, "getting repetition rules", err, http.StatusInternalServerError)
return
}
@ -288,9 +286,8 @@ func (a *App) createRepetitionRule(w http.ResponseWriter, r *http.Request) {
return
}
db := database.DBConn
var books []database.Book
if err := db.Where("user_id = ? AND uuid IN (?)", user.ID, params.GetBookUUIDs()).Find(&books).Error; err != nil {
if err := a.DB.Where("user_id = ? AND uuid IN (?)", user.ID, params.GetBookUUIDs()).Find(&books).Error; err != nil {
HandleError(w, "finding books", nil, http.StatusInternalServerError)
return
}
@ -313,7 +310,7 @@ func (a *App) createRepetitionRule(w http.ResponseWriter, r *http.Request) {
NoteCount: params.GetNoteCount(),
Enabled: params.GetEnabled(),
}
if err := db.Create(&record).Error; err != nil {
if err := a.DB.Create(&record).Error; err != nil {
HandleError(w, "creating a repetition rule", err, http.StatusInternalServerError)
return
}
@ -346,10 +343,8 @@ func (a *App) deleteRepetitionRule(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
repetitionRuleUUID := vars["repetitionRuleUUID"]
db := database.DBConn
var rule database.RepetitionRule
conn := db.Where("uuid = ? AND user_id = ?", repetitionRuleUUID, user.ID).First(&rule)
conn := a.DB.Where("uuid = ? AND user_id = ?", repetitionRuleUUID, user.ID).First(&rule)
if conn.RecordNotFound() {
http.Error(w, "Not found", http.StatusNotFound)
@ -359,7 +354,7 @@ func (a *App) deleteRepetitionRule(w http.ResponseWriter, r *http.Request) {
return
}
if err := db.Exec("DELETE from repetition_rules WHERE uuid = ?", rule.UUID).Error; err != nil {
if err := a.DB.Exec("DELETE from repetition_rules WHERE uuid = ?", rule.UUID).Error; err != nil {
HandleError(w, "deleting the repetition rule", err, http.StatusInternalServerError)
}
@ -382,8 +377,7 @@ func (a *App) updateRepetitionRule(w http.ResponseWriter, r *http.Request) {
return
}
db := database.DBConn
tx := db.Begin()
tx := a.DB.Begin()
var repetitionRule database.RepetitionRule
if err := tx.Where("user_id = ? AND uuid = ?", user.ID, repetitionRuleUUID).Preload("Books").First(&repetitionRule).Error; err != nil {

View file

@ -27,22 +27,19 @@ import (
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
)
func init() {
testutils.InitTestDB()
}
func TestGetRepetitionRule(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -53,7 +50,7 @@ func TestGetRepetitionRule(t *testing.T) {
USN: 11,
Label: "js",
}
testutils.MustExec(t, db.Save(&b1), "preparing book1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing book1")
r1 := database.RepetitionRule{
Title: "Rule 1",
@ -66,7 +63,7 @@ func TestGetRepetitionRule(t *testing.T) {
Books: []database.Book{b1},
NoteCount: 5,
}
testutils.MustExec(t, db.Save(&r1), "preparing rule1")
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
// Execute
req := testutils.MakeReq(server, "GET", fmt.Sprintf("/repetition_rules/%s", r1.UUID), "")
@ -81,9 +78,9 @@ func TestGetRepetitionRule(t *testing.T) {
}
var r1Record database.RepetitionRule
testutils.MustExec(t, db.Where("uuid = ?", r1.UUID).First(&r1Record), "finding r1Record")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", r1.UUID).First(&r1Record), "finding r1Record")
var b1Record database.Book
testutils.MustExec(t, db.Where("uuid = ?", b1.UUID).First(&b1Record), "finding b1Record")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", b1.UUID).First(&b1Record), "finding b1Record")
expected := presenters.RepetitionRule{
UUID: r1Record.UUID,
@ -112,11 +109,12 @@ func TestGetRepetitionRule(t *testing.T) {
}
func TestGetRepetitionRules(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -127,7 +125,7 @@ func TestGetRepetitionRules(t *testing.T) {
USN: 11,
Label: "js",
}
testutils.MustExec(t, db.Save(&b1), "preparing book1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing book1")
r1 := database.RepetitionRule{
Title: "Rule 1",
@ -140,7 +138,7 @@ func TestGetRepetitionRules(t *testing.T) {
Books: []database.Book{b1},
NoteCount: 5,
}
testutils.MustExec(t, db.Save(&r1), "preparing rule1")
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
r2 := database.RepetitionRule{
Title: "Rule 2",
@ -153,7 +151,7 @@ func TestGetRepetitionRules(t *testing.T) {
Books: []database.Book{},
NoteCount: 5,
}
testutils.MustExec(t, db.Save(&r2), "preparing rule2")
testutils.MustExec(t, testutils.DB.Save(&r2), "preparing rule2")
// Execute
req := testutils.MakeReq(server, "GET", "/repetition_rules", "")
@ -168,10 +166,10 @@ func TestGetRepetitionRules(t *testing.T) {
}
var r1Record, r2Record database.RepetitionRule
testutils.MustExec(t, db.Where("uuid = ?", r1.UUID).First(&r1Record), "finding r1Record")
testutils.MustExec(t, db.Where("uuid = ?", r2.UUID).First(&r2Record), "finding r2Record")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", r1.UUID).First(&r1Record), "finding r1Record")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", r2.UUID).First(&r2Record), "finding r2Record")
var b1Record database.Book
testutils.MustExec(t, db.Where("uuid = ?", b1.UUID).First(&b1Record), "finding b1Record")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", b1.UUID).First(&b1Record), "finding b1Record")
expected := []presenters.RepetitionRule{
{
@ -217,8 +215,8 @@ func TestGetRepetitionRules(t *testing.T) {
func TestCreateRepetitionRules(t *testing.T) {
t.Run("all books", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
c := clock.NewMock()
@ -226,6 +224,7 @@ func TestCreateRepetitionRules(t *testing.T) {
c.SetNow(t0)
server := MustNewServer(t, &App{
Clock: c,
})
defer server.Close()
@ -250,11 +249,11 @@ func TestCreateRepetitionRules(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusCreated, "")
var ruleCount int
testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
assert.Equalf(t, ruleCount, 1, "reperition rule count mismatch")
var rule database.RepetitionRule
testutils.MustExec(t, db.Preload("Books").First(&rule), "finding b1Record")
testutils.MustExec(t, testutils.DB.Preload("Books").First(&rule), "finding b1Record")
assert.NotEqual(t, rule.UUID, "", "rule UUID mismatch")
assert.Equal(t, rule.Title, "Rule 1", "rule Title mismatch")
@ -275,8 +274,8 @@ func TestCreateRepetitionRules(t *testing.T) {
}
for _, tc := range bookDomainTestCases {
t.Run(tc, func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
c := clock.NewMock()
@ -284,6 +283,7 @@ func TestCreateRepetitionRules(t *testing.T) {
c.SetNow(t0)
server := MustNewServer(t, &App{
Clock: c,
})
defer server.Close()
@ -294,7 +294,7 @@ func TestCreateRepetitionRules(t *testing.T) {
UserID: user.ID,
Label: "css",
}
testutils.MustExec(t, db.Save(&b1), "preparing b1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
// Execute
dat := fmt.Sprintf(`{
@ -314,14 +314,14 @@ func TestCreateRepetitionRules(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusCreated, "")
var ruleCount int
testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
assert.Equalf(t, ruleCount, 1, "reperition rule count mismatch")
var rule database.RepetitionRule
testutils.MustExec(t, db.Preload("Books").First(&rule), "finding b1Record")
testutils.MustExec(t, testutils.DB.Preload("Books").First(&rule), "finding b1Record")
var b1Record database.Book
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1Record")
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1Record")
assert.NotEqual(t, rule.UUID, "", "rule UUID mismatch")
assert.Equal(t, rule.Title, "Rule 1", "rule Title mismatch")
@ -339,14 +339,15 @@ func TestCreateRepetitionRules(t *testing.T) {
}
func TestUpdateRepetitionRules(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
c := clock.NewMock()
t0 := time.Date(2009, time.November, 1, 2, 3, 4, 5, time.UTC)
c.SetNow(t0)
server := MustNewServer(t, &App{
Clock: c,
})
defer server.Close()
@ -367,13 +368,13 @@ func TestUpdateRepetitionRules(t *testing.T) {
Books: []database.Book{},
NoteCount: 20,
}
testutils.MustExec(t, db.Save(&r1), "preparing r1")
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing r1")
b1 := database.Book{
UserID: user.ID,
USN: 11,
Label: "js",
}
testutils.MustExec(t, db.Save(&b1), "preparing book1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing book1")
dat := fmt.Sprintf(`{
"title": "Rule 1 - edited",
@ -393,14 +394,14 @@ func TestUpdateRepetitionRules(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "")
var totalRuleCount int
testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&totalRuleCount), "counting rules")
testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&totalRuleCount), "counting rules")
assert.Equalf(t, totalRuleCount, 1, "reperition rule count mismatch")
var rule database.RepetitionRule
testutils.MustExec(t, db.Preload("Books").First(&rule), "finding b1Record")
testutils.MustExec(t, testutils.DB.Preload("Books").First(&rule), "finding b1Record")
var b1Record database.Book
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1Record")
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1Record")
assert.NotEqual(t, rule.UUID, "", "rule UUID mismatch")
assert.Equal(t, rule.Title, "Rule 1 - edited", "rule Title mismatch")
@ -416,11 +417,12 @@ func TestUpdateRepetitionRules(t *testing.T) {
}
func TestDeleteRepetitionRules(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -439,7 +441,7 @@ func TestDeleteRepetitionRules(t *testing.T) {
Books: []database.Book{},
NoteCount: 20,
}
testutils.MustExec(t, db.Save(&r1), "preparing r1")
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing r1")
r2 := database.RepetitionRule{
Title: "Rule 1",
@ -452,7 +454,7 @@ func TestDeleteRepetitionRules(t *testing.T) {
Books: []database.Book{},
NoteCount: 20,
}
testutils.MustExec(t, db.Save(&r2), "preparing r2")
testutils.MustExec(t, testutils.DB.Save(&r2), "preparing r2")
endpoint := fmt.Sprintf("/repetition_rules/%s", r1.UUID)
req := testutils.MakeReq(server, "DELETE", endpoint, "")
@ -462,11 +464,11 @@ func TestDeleteRepetitionRules(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "")
var totalRuleCount int
testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&totalRuleCount), "counting rules")
testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&totalRuleCount), "counting rules")
assert.Equalf(t, totalRuleCount, 1, "reperition rule count mismatch")
var r2Count int
testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Where("id = ?", r2.ID).Count(&r2Count), "counting r2")
testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Where("id = ?", r2.ID).Count(&r2Count), "counting r2")
assert.Equalf(t, r2Count, 1, "r2 count mismatch")
}
@ -541,11 +543,12 @@ func TestCreateUpdateRepetitionRules_BadRequest(t *testing.T) {
for idx, tc := range testCases {
t.Run(fmt.Sprintf("test case - create %d", idx), func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -560,13 +563,13 @@ func TestCreateUpdateRepetitionRules_BadRequest(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "")
var ruleCount int
testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
assert.Equalf(t, ruleCount, 0, "reperition rule count mismatch")
})
t.Run(fmt.Sprintf("test case %d - update", idx), func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
user := testutils.SetupUserData()
@ -581,15 +584,16 @@ func TestCreateUpdateRepetitionRules_BadRequest(t *testing.T) {
Books: []database.Book{},
NoteCount: 20,
}
testutils.MustExec(t, db.Save(&r1), "preparing r1")
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing r1")
b1 := database.Book{
UserID: user.ID,
USN: 11,
Label: "js",
}
testutils.MustExec(t, db.Save(&b1), "preparing book1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing book1")
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -602,7 +606,7 @@ func TestCreateUpdateRepetitionRules_BadRequest(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "")
var ruleCount int
testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
assert.Equalf(t, ruleCount, 1, "reperition rule count mismatch")
})
}
@ -624,11 +628,12 @@ func TestCreateRepetitionRules_BadRequest(t *testing.T) {
for idx, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -643,7 +648,7 @@ func TestCreateRepetitionRules_BadRequest(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "")
var ruleCount int
testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
assert.Equalf(t, ruleCount, 0, "reperition rule count mismatch")
})
}

View file

@ -31,6 +31,7 @@ import (
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/log"
"github.com/gorilla/mux"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
"github.com/stripe/stripe-go"
)
@ -63,28 +64,6 @@ func parseAuthHeader(h string) (authHeader, error) {
return parsed, nil
}
func legacyAuth(next http.HandlerFunc) http.HandlerFunc {
db := database.DBConn
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := r.Cookie("api_key")
if err != nil {
http.Error(w, "Invalid API key", http.StatusUnauthorized)
return
}
apiKey := c.Value
var user database.User
if db.Where("api_key = ?", apiKey).First(&user).RecordNotFound() {
http.Error(w, "Invalid API key", http.StatusUnauthorized)
return
}
ctx := context.WithValue(r.Context(), helpers.KeyUser, user)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// getSessionKeyFromCookie reads and returns a session key from the cookie sent by the
// request. If no session key is found, it returns an empty string
func getSessionKeyFromCookie(r *http.Request) (string, error) {
@ -138,8 +117,7 @@ func getCredential(r *http.Request) (string, error) {
}
// AuthWithSession performs user authentication with session
func AuthWithSession(r *http.Request, p *AuthMiddlewareParams) (database.User, bool, error) {
db := database.DBConn
func AuthWithSession(db *gorm.DB, r *http.Request, p *AuthMiddlewareParams) (database.User, bool, error) {
var user database.User
sessionKey, err := getCredential(r)
@ -174,8 +152,7 @@ func AuthWithSession(r *http.Request, p *AuthMiddlewareParams) (database.User, b
return user, true, nil
}
func authWithToken(r *http.Request, tokenType string, p *AuthMiddlewareParams) (database.User, database.Token, bool, error) {
db := database.DBConn
func authWithToken(db *gorm.DB, r *http.Request, tokenType string, p *AuthMiddlewareParams) (database.User, database.Token, bool, error) {
var user database.User
var token database.Token
@ -208,9 +185,9 @@ type AuthMiddlewareParams struct {
ProOnly bool
}
func auth(next http.HandlerFunc, p *AuthMiddlewareParams) http.HandlerFunc {
func (a *App) auth(next http.HandlerFunc, p *AuthMiddlewareParams) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, ok, err := AuthWithSession(r, p)
user, ok, err := AuthWithSession(a.DB, r, p)
if !ok {
respondUnauthorized(w)
return
@ -231,9 +208,9 @@ func auth(next http.HandlerFunc, p *AuthMiddlewareParams) http.HandlerFunc {
})
}
func tokenAuth(next http.HandlerFunc, tokenType string, p *AuthMiddlewareParams) http.HandlerFunc {
func (a *App) tokenAuth(next http.HandlerFunc, tokenType string, p *AuthMiddlewareParams) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user, token, ok, err := authWithToken(r, tokenType, p)
user, token, ok, err := authWithToken(a.DB, r, tokenType, p)
if err != nil {
// log the error and continue
log.ErrorWrap(err, "authenticating with token")
@ -245,7 +222,7 @@ func tokenAuth(next http.HandlerFunc, tokenType string, p *AuthMiddlewareParams)
ctx = context.WithValue(ctx, helpers.KeyToken, token)
} else {
// If token-based auth fails, fall back to session-based auth
user, ok, err = AuthWithSession(r, p)
user, ok, err = AuthWithSession(a.DB, r, p)
if err != nil {
HandleError(w, "authenticating with session", err, http.StatusInternalServerError)
return
@ -325,6 +302,7 @@ func applyMiddleware(h http.HandlerFunc, rateLimit bool) http.Handler {
// App is an application configuration
type App struct {
DB *gorm.DB
Clock clock.Clock
StripeAPIBackend stripe.Backend
WebURL string
@ -334,6 +312,9 @@ func (a *App) validate() error {
if a.WebURL == "" {
return errors.New("WebURL is empty")
}
if a.DB == nil {
return errors.New("DB is empty")
}
return nil
}
@ -364,51 +345,50 @@ func NewRouter(app *App) (*mux.Router, error) {
var routes = []Route{
// internal
{"GET", "/health", app.checkHealth, false},
{"GET", "/me", auth(app.getMe, nil), true},
{"POST", "/verification-token", auth(app.createVerificationToken, nil), true},
{"GET", "/me", app.auth(app.getMe, nil), true},
{"POST", "/verification-token", app.auth(app.createVerificationToken, nil), true},
{"PATCH", "/verify-email", app.verifyEmail, true},
{"POST", "/reset-token", app.createResetToken, true},
{"PATCH", "/reset-password", app.resetPassword, true},
{"PATCH", "/account/profile", auth(app.updateProfile, nil), true},
{"PATCH", "/account/password", auth(app.updatePassword, nil), true},
{"GET", "/account/email-preference", tokenAuth(app.getEmailPreference, database.TokenTypeEmailPreference, nil), true},
{"PATCH", "/account/email-preference", tokenAuth(app.updateEmailPreference, database.TokenTypeEmailPreference, nil), true},
{"POST", "/subscriptions", auth(app.createSub, nil), true},
{"PATCH", "/subscriptions", auth(app.updateSub, nil), true},
{"PATCH", "/account/profile", app.auth(app.updateProfile, nil), true},
{"PATCH", "/account/password", app.auth(app.updatePassword, nil), true},
{"GET", "/account/email-preference", app.tokenAuth(app.getEmailPreference, database.TokenTypeEmailPreference, nil), true},
{"PATCH", "/account/email-preference", app.tokenAuth(app.updateEmailPreference, database.TokenTypeEmailPreference, nil), true},
{"POST", "/subscriptions", app.auth(app.createSub, nil), true},
{"PATCH", "/subscriptions", app.auth(app.updateSub, nil), true},
{"POST", "/webhooks/stripe", app.stripeWebhook, true},
{"GET", "/subscriptions", auth(app.getSub, nil), true},
{"GET", "/stripe_source", auth(app.getStripeSource, nil), true},
{"PATCH", "/stripe_source", auth(app.updateStripeSource, nil), true},
{"GET", "/notes", auth(app.getNotes, &proOnly), false},
{"GET", "/subscriptions", app.auth(app.getSub, nil), true},
{"GET", "/stripe_source", app.auth(app.getStripeSource, nil), true},
{"PATCH", "/stripe_source", app.auth(app.updateStripeSource, nil), true},
{"GET", "/notes", app.auth(app.getNotes, &proOnly), false},
{"GET", "/notes/{noteUUID}", app.getNote, true},
{"GET", "/calendar", auth(app.getCalendar, &proOnly), true},
{"GET", "/repetition_rules", auth(app.getRepetitionRules, &proOnly), true},
{"GET", "/repetition_rules/{repetitionRuleUUID}", tokenAuth(app.getRepetitionRule, database.TokenTypeRepetition, &proOnly), true},
{"POST", "/repetition_rules", auth(app.createRepetitionRule, &proOnly), true},
{"PATCH", "/repetition_rules/{repetitionRuleUUID}", tokenAuth(app.updateRepetitionRule, database.TokenTypeRepetition, &proOnly), true},
{"DELETE", "/repetition_rules/{repetitionRuleUUID}", auth(app.deleteRepetitionRule, &proOnly), true},
{"GET", "/calendar", app.auth(app.getCalendar, &proOnly), true},
{"GET", "/repetition_rules", app.auth(app.getRepetitionRules, &proOnly), true},
{"GET", "/repetition_rules/{repetitionRuleUUID}", app.tokenAuth(app.getRepetitionRule, database.TokenTypeRepetition, &proOnly), true},
{"POST", "/repetition_rules", app.auth(app.createRepetitionRule, &proOnly), true},
{"PATCH", "/repetition_rules/{repetitionRuleUUID}", app.tokenAuth(app.updateRepetitionRule, database.TokenTypeRepetition, &proOnly), true},
{"DELETE", "/repetition_rules/{repetitionRuleUUID}", app.auth(app.deleteRepetitionRule, &proOnly), true},
// migration of classic users
{"GET", "/classic/presignin", cors(app.classicPresignin), true},
{"POST", "/classic/signin", cors(app.classicSignin), true},
{"PATCH", "/classic/migrate", auth(app.classicMigrate, &proOnly), true},
{"GET", "/classic/notes", auth(app.classicGetNotes, nil), true},
{"PATCH", "/classic/set-password", auth(app.classicSetPassword, nil), true},
{"PATCH", "/classic/migrate", app.auth(app.classicMigrate, &proOnly), true},
{"GET", "/classic/notes", app.auth(app.classicGetNotes, nil), true},
{"PATCH", "/classic/set-password", app.auth(app.classicSetPassword, nil), true},
// v3
{"GET", "/v3/sync/fragment", cors(auth(app.GetSyncFragment, &proOnly)), true},
{"GET", "/v3/sync/state", cors(auth(app.GetSyncState, &proOnly)), true},
{"GET", "/v3/sync/fragment", cors(app.auth(app.GetSyncFragment, &proOnly)), true},
{"GET", "/v3/sync/state", cors(app.auth(app.GetSyncState, &proOnly)), true},
{"OPTIONS", "/v3/books", cors(app.BooksOptions), true},
{"GET", "/v3/books", cors(auth(app.GetBooks, &proOnly)), true},
{"GET", "/v3/books/{bookUUID}", cors(auth(app.GetBook, &proOnly)), true},
{"POST", "/v3/books", cors(auth(app.CreateBook, &proOnly)), true},
{"PATCH", "/v3/books/{bookUUID}", cors(auth(app.UpdateBook, &proOnly)), false},
{"DELETE", "/v3/books/{bookUUID}", cors(auth(app.DeleteBook, &proOnly)), false},
{"GET", "/v3/demo/books", app.GetDemoBooks, true},
{"GET", "/v3/books", cors(app.auth(app.GetBooks, &proOnly)), true},
{"GET", "/v3/books/{bookUUID}", cors(app.auth(app.GetBook, &proOnly)), true},
{"POST", "/v3/books", cors(app.auth(app.CreateBook, &proOnly)), true},
{"PATCH", "/v3/books/{bookUUID}", cors(app.auth(app.UpdateBook, &proOnly)), false},
{"DELETE", "/v3/books/{bookUUID}", cors(app.auth(app.DeleteBook, &proOnly)), false},
{"OPTIONS", "/v3/notes", cors(app.NotesOptions), true},
{"POST", "/v3/notes", cors(auth(app.CreateNote, &proOnly)), true},
{"PATCH", "/v3/notes/{noteUUID}", auth(app.UpdateNote, &proOnly), false},
{"DELETE", "/v3/notes/{noteUUID}", auth(app.DeleteNote, &proOnly), false},
{"POST", "/v3/notes", cors(app.auth(app.CreateNote, &proOnly)), true},
{"PATCH", "/v3/notes/{noteUUID}", app.auth(app.UpdateNote, &proOnly), false},
{"DELETE", "/v3/notes/{noteUUID}", app.auth(app.DeleteNote, &proOnly), false},
{"POST", "/v3/signin", cors(app.signin), true},
{"OPTIONS", "/v3/signout", cors(app.signoutOptions), true},
{"POST", "/v3/signout", cors(app.signout), true},

View file

@ -29,13 +29,10 @@ import (
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/testutils"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
)
func init() {
testutils.InitTestDB()
}
func TestGetSessionKeyFromCookie(t *testing.T) {
testCases := []struct {
cookie *http.Cookie
@ -185,10 +182,8 @@ func TestGetCredential(t *testing.T) {
}
func TestAuthMiddleware(t *testing.T) {
defer testutils.ClearData()
// set up
db := database.DBConn
defer testutils.ClearData()
user := testutils.SetupUserData()
session := database.Session{
@ -196,18 +191,19 @@ func TestAuthMiddleware(t *testing.T) {
UserID: user.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
testutils.MustExec(t, db.Save(&session), "preparing session")
testutils.MustExec(t, testutils.DB.Save(&session), "preparing session")
session2 := database.Session{
Key: "Vvgm3eBXfXGEFWERI7faiRJ3DAzJw+7DdT9J1LEyNfI=",
UserID: user.ID,
ExpiresAt: time.Now().Add(-time.Hour * 24),
}
testutils.MustExec(t, db.Save(&session2), "preparing session")
testutils.MustExec(t, testutils.DB.Save(&session2), "preparing session")
handler := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
server := httptest.NewServer(auth(handler, nil))
app := App{DB: testutils.DB}
server := httptest.NewServer(app.auth(handler, nil))
defer server.Close()
t.Run("with header", func(t *testing.T) {
@ -300,24 +296,23 @@ func TestAuthMiddleware(t *testing.T) {
}
func TestAuthMiddleware_ProOnly(t *testing.T) {
defer testutils.ClearData()
// set up
db := database.DBConn
user := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&user).Update("cloud", false), "preparing session")
testutils.MustExec(t, testutils.DB.Model(&user).Update("cloud", false), "preparing session")
session := database.Session{
Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=",
UserID: user.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
testutils.MustExec(t, db.Save(&session), "preparing session")
testutils.MustExec(t, testutils.DB.Save(&session), "preparing session")
handler := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
server := httptest.NewServer(auth(handler, &AuthMiddlewareParams{
app := App{DB: testutils.DB}
server := httptest.NewServer(app.auth(handler, &AuthMiddlewareParams{
ProOnly: true,
}))
defer server.Close()
@ -390,10 +385,8 @@ func TestAuthMiddleware_ProOnly(t *testing.T) {
}
func TestTokenAuthMiddleWare(t *testing.T) {
defer testutils.ClearData()
// set up
db := database.DBConn
defer testutils.ClearData()
user := testutils.SetupUserData()
tok := database.Token{
@ -401,18 +394,19 @@ func TestTokenAuthMiddleWare(t *testing.T) {
Type: database.TokenTypeEmailPreference,
Value: "xpwFnc0MdllFUePDq9DLeQ==",
}
testutils.MustExec(t, db.Save(&tok), "preparing token")
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
session := database.Session{
Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=",
UserID: user.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
testutils.MustExec(t, db.Save(&session), "preparing session")
testutils.MustExec(t, testutils.DB.Save(&session), "preparing session")
handler := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
server := httptest.NewServer(tokenAuth(handler, database.TokenTypeEmailPreference, nil))
app := App{DB: testutils.DB}
server := httptest.NewServer(app.tokenAuth(handler, database.TokenTypeEmailPreference, nil))
defer server.Close()
t.Run("with token", func(t *testing.T) {
@ -521,30 +515,29 @@ func TestTokenAuthMiddleWare(t *testing.T) {
}
func TestTokenAuthMiddleWare_ProOnly(t *testing.T) {
defer testutils.ClearData()
// set up
db := database.DBConn
user := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&user).Update("cloud", false), "preparing session")
testutils.MustExec(t, testutils.DB.Model(&user).Update("cloud", false), "preparing session")
tok := database.Token{
UserID: user.ID,
Type: database.TokenTypeEmailPreference,
Value: "xpwFnc0MdllFUePDq9DLeQ==",
}
testutils.MustExec(t, db.Save(&tok), "preparing token")
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
session := database.Session{
Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=",
UserID: user.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
testutils.MustExec(t, db.Save(&session), "preparing session")
testutils.MustExec(t, testutils.DB.Save(&session), "preparing session")
handler := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
server := httptest.NewServer(tokenAuth(handler, database.TokenTypeEmailPreference, &AuthMiddlewareParams{
app := App{DB: testutils.DB}
server := httptest.NewServer(app.tokenAuth(handler, database.TokenTypeEmailPreference, &AuthMiddlewareParams{
ProOnly: true,
}))
defer server.Close()
@ -682,6 +675,7 @@ func TestNotSupportedVersions(t *testing.T) {
// setup
server := MustNewServer(t, &App{
DB: &gorm.DB{},
Clock: clock.NewMock(),
})
defer server.Close()

View file

@ -26,9 +26,9 @@ import (
"os"
"strings"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/operations"
"github.com/dnote/dnote/pkg/server/database"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
"github.com/stripe/stripe-go"
@ -138,8 +138,7 @@ func (a *App) createSub(w http.ResponseWriter, r *http.Request) {
return
}
db := database.DBConn
tx := db.Begin()
tx := a.DB.Begin()
if err := tx.Model(&user).
Update(map[string]interface{}{
@ -431,8 +430,7 @@ func (a *App) updateStripeSource(w http.ResponseWriter, r *http.Request) {
return
}
db := database.DBConn
tx := db.Begin()
tx := a.DB.Begin()
if err := tx.Model(&user).
Update(map[string]interface{}{
@ -532,7 +530,7 @@ func (a *App) stripeWebhook(w http.ResponseWriter, req *http.Request) {
return
}
operations.MarkUnsubscribed(subscription.Customer.ID)
operations.MarkUnsubscribed(a.DB, subscription.Customer.ID)
}
default:
{

View file

@ -23,6 +23,7 @@ import (
"os"
"testing"
"github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
)
@ -30,6 +31,7 @@ import (
// with the given app paratmers
func MustNewServer(t *testing.T, app *App) *httptest.Server {
app.WebURL = os.Getenv("WebURL")
app.DB = testutils.DB
r, err := NewRouter(app)
if err != nil {

View file

@ -23,11 +23,12 @@ import (
"net/http"
"time"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/log"
"github.com/dnote/dnote/pkg/server/mailer"
"github.com/dnote/dnote/pkg/server/presenters"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
)
@ -38,8 +39,6 @@ type updateProfilePayload struct {
// updateProfile updates user
func (a *App) updateProfile(w http.ResponseWriter, r *http.Request) {
db := database.DBConn
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
if !ok {
HandleError(w, "No authenticated user found", nil, http.StatusInternalServerError)
@ -60,13 +59,13 @@ func (a *App) updateProfile(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
err = db.Where("user_id = ?", user.ID).First(&account).Error
err = a.DB.Where("user_id = ?", user.ID).First(&account).Error
if err != nil {
HandleError(w, "finding account", err, http.StatusInternalServerError)
return
}
tx := db.Begin()
tx := a.DB.Begin()
if err := tx.Save(&user).Error; err != nil {
tx.Rollback()
HandleError(w, "saving user", err, http.StatusInternalServerError)
@ -87,7 +86,7 @@ func (a *App) updateProfile(w http.ResponseWriter, r *http.Request) {
tx.Commit()
respondWithSession(w, user.ID, http.StatusOK)
respondWithSession(a.DB, w, user.ID, http.StatusOK)
}
type updateEmailPayload struct {
@ -97,9 +96,7 @@ type updateEmailPayload struct {
NewAuthKey string `json:"new_auth_key"`
}
func respondWithCalendar(w http.ResponseWriter, userID int) {
db := database.DBConn
func respondWithCalendar(db *gorm.DB, w http.ResponseWriter, userID int) {
rows, err := db.Table("notes").Select("COUNT(id), date(to_timestamp(added_on/1000000000)) AS added_date").
Where("user_id = ?", userID).
Group("added_date").
@ -132,22 +129,10 @@ func (a *App) getCalendar(w http.ResponseWriter, r *http.Request) {
return
}
respondWithCalendar(w, user.ID)
}
func (a *App) getDemoCalendar(w http.ResponseWriter, r *http.Request) {
userID, err := helpers.GetDemoUserID()
if err != nil {
HandleError(w, "finding demo user", err, http.StatusInternalServerError)
return
}
respondWithCalendar(w, userID)
respondWithCalendar(a.DB, w, user.ID)
}
func (a *App) createVerificationToken(w http.ResponseWriter, r *http.Request) {
db := database.DBConn
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
if !ok {
HandleError(w, "No authenticated user found", nil, http.StatusInternalServerError)
@ -155,7 +140,7 @@ func (a *App) createVerificationToken(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
err := db.Where("user_id = ?", user.ID).First(&account).Error
err := a.DB.Where("user_id = ?", user.ID).First(&account).Error
if err != nil {
HandleError(w, "finding account", err, http.StatusInternalServerError)
return
@ -182,7 +167,7 @@ func (a *App) createVerificationToken(w http.ResponseWriter, r *http.Request) {
Type: database.TokenTypeEmailVerification,
}
if err := db.Save(&token).Error; err != nil {
if err := a.DB.Save(&token).Error; err != nil {
HandleError(w, "saving token", err, http.StatusInternalServerError)
return
}
@ -212,8 +197,6 @@ type verifyEmailPayload struct {
}
func (a *App) verifyEmail(w http.ResponseWriter, r *http.Request) {
db := database.DBConn
var params verifyEmailPayload
if err := json.NewDecoder(r.Body).Decode(&params); err != nil {
HandleError(w, "decoding payload", err, http.StatusInternalServerError)
@ -221,7 +204,7 @@ func (a *App) verifyEmail(w http.ResponseWriter, r *http.Request) {
}
var token database.Token
if err := db.
if err := a.DB.
Where("value = ? AND type = ?", params.Token, database.TokenTypeEmailVerification).
First(&token).Error; err != nil {
http.Error(w, "invalid token", http.StatusBadRequest)
@ -240,7 +223,7 @@ func (a *App) verifyEmail(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
if err := db.Where("user_id = ?", token.UserID).First(&account).Error; err != nil {
if err := a.DB.Where("user_id = ?", token.UserID).First(&account).Error; err != nil {
HandleError(w, "finding account", err, http.StatusInternalServerError)
return
}
@ -249,7 +232,7 @@ func (a *App) verifyEmail(w http.ResponseWriter, r *http.Request) {
return
}
tx := db.Begin()
tx := a.DB.Begin()
account.EmailVerified = true
if err := tx.Save(&account).Error; err != nil {
tx.Rollback()
@ -264,7 +247,7 @@ func (a *App) verifyEmail(w http.ResponseWriter, r *http.Request) {
tx.Commit()
var user database.User
if err := db.Where("id = ?", token.UserID).First(&user).Error; err != nil {
if err := a.DB.Where("id = ?", token.UserID).First(&user).Error; err != nil {
HandleError(w, "finding user", err, http.StatusInternalServerError)
return
}
@ -278,8 +261,6 @@ type updateEmailPreferencePayload struct {
}
func (a *App) updateEmailPreference(w http.ResponseWriter, r *http.Request) {
db := database.DBConn
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
if !ok {
HandleError(w, "No authenticated user found", nil, http.StatusInternalServerError)
@ -293,12 +274,12 @@ func (a *App) updateEmailPreference(w http.ResponseWriter, r *http.Request) {
}
var frequency database.EmailPreference
if err := db.Where(database.EmailPreference{UserID: user.ID}).FirstOrCreate(&frequency).Error; err != nil {
if err := a.DB.Where(database.EmailPreference{UserID: user.ID}).FirstOrCreate(&frequency).Error; err != nil {
HandleError(w, "finding frequency", err, http.StatusInternalServerError)
return
}
tx := db.Begin()
tx := a.DB.Begin()
frequency.DigestWeekly = params.DigestWeekly
if err := tx.Save(&frequency).Error; err != nil {
@ -323,8 +304,6 @@ func (a *App) updateEmailPreference(w http.ResponseWriter, r *http.Request) {
}
func (a *App) getEmailPreference(w http.ResponseWriter, r *http.Request) {
db := database.DBConn
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
if !ok {
HandleError(w, "No authenticated user found", nil, http.StatusInternalServerError)
@ -332,7 +311,7 @@ func (a *App) getEmailPreference(w http.ResponseWriter, r *http.Request) {
}
var pref database.EmailPreference
if err := db.Where(database.EmailPreference{UserID: user.ID}).First(&pref).Error; err != nil {
if err := a.DB.Where(database.EmailPreference{UserID: user.ID}).First(&pref).Error; err != nil {
HandleError(w, "finding pref", err, http.StatusInternalServerError)
return
}
@ -347,8 +326,6 @@ type updatePasswordPayload struct {
}
func (a *App) updatePassword(w http.ResponseWriter, r *http.Request) {
db := database.DBConn
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
if !ok {
HandleError(w, "No authenticated user found", nil, http.StatusInternalServerError)
@ -366,7 +343,7 @@ func (a *App) updatePassword(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
if err := a.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
HandleError(w, "getting user", nil, http.StatusInternalServerError)
return
}
@ -391,7 +368,7 @@ func (a *App) updatePassword(w http.ResponseWriter, r *http.Request) {
return
}
if err := db.Model(&account).Update("password", string(hashedNewPassword)).Error; err != nil {
if err := a.DB.Model(&account).Update("password", string(hashedNewPassword)).Error; err != nil {
http.Error(w, errors.Wrap(err, "updating password").Error(), http.StatusInternalServerError)
return
}

View file

@ -36,18 +36,14 @@ import (
"golang.org/x/crypto/bcrypt"
)
func init() {
testutils.InitTestDB()
}
func TestUpdatePassword(t *testing.T) {
t.Run("success", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -64,18 +60,19 @@ func TestUpdatePassword(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "Status code mismsatch")
var account database.Account
testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte("newpassword"))
assert.Equal(t, passwordErr, nil, "Password mismatch")
})
t.Run("old password mismatch", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -92,16 +89,17 @@ func TestUpdatePassword(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "Status code mismsatch")
var account database.Account
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account")
assert.Equal(t, a.Password.String, account.Password.String, "password should not have been updated")
})
t.Run("password too short", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -118,15 +116,15 @@ func TestUpdatePassword(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismsatch")
var account database.Account
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account")
assert.Equal(t, a.Password.String, account.Password.String, "password should not have been updated")
})
}
func TestCreateVerificationToken(t *testing.T) {
t.Run("success", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
@ -137,6 +135,7 @@ func TestCreateVerificationToken(t *testing.T) {
mailer.InitTemplates(&templatePath)
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -154,9 +153,9 @@ func TestCreateVerificationToken(t *testing.T) {
var account database.Account
var token database.Token
var tokenCount int
testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
assert.Equal(t, account.EmailVerified, false, "email_verified should not have been updated")
assert.NotEqual(t, token.Value, "", "token Value mismatch")
@ -165,11 +164,12 @@ func TestCreateVerificationToken(t *testing.T) {
})
t.Run("already verified", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -177,7 +177,7 @@ func TestCreateVerificationToken(t *testing.T) {
user := testutils.SetupUserData()
a := testutils.SetupAccountData(user, "alice@example.com", "pass1234")
a.EmailVerified = true
testutils.MustExec(t, db.Save(&a), "preparing account")
testutils.MustExec(t, testutils.DB.Save(&a), "preparing account")
// Execute
req := testutils.MakeReq(server, "POST", "/verification-token", "")
@ -188,8 +188,8 @@ func TestCreateVerificationToken(t *testing.T) {
var account database.Account
var tokenCount int
testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
assert.Equal(t, account.EmailVerified, true, "email_verified should not have been updated")
assert.Equal(t, tokenCount, 0, "token count mismatch")
@ -198,11 +198,12 @@ func TestCreateVerificationToken(t *testing.T) {
func TestVerifyEmail(t *testing.T) {
t.Run("success", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -214,7 +215,7 @@ func TestVerifyEmail(t *testing.T) {
Type: database.TokenTypeEmailVerification,
Value: "someTokenValue",
}
testutils.MustExec(t, db.Save(&tok), "preparing token")
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
dat := `{"token": "someTokenValue"}`
req := testutils.MakeReq(server, "PATCH", "/verify-email", dat)
@ -228,9 +229,9 @@ func TestVerifyEmail(t *testing.T) {
var account database.Account
var token database.Token
var tokenCount int
testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
assert.Equal(t, account.EmailVerified, true, "email_verified mismatch")
assert.NotEqual(t, token.Value, "", "token value should not have been updated")
@ -239,11 +240,12 @@ func TestVerifyEmail(t *testing.T) {
})
t.Run("used token", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -258,7 +260,7 @@ func TestVerifyEmail(t *testing.T) {
Value: "someTokenValue",
UsedAt: &usedAt,
}
testutils.MustExec(t, db.Save(&tok), "preparing token")
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
dat := `{"token": "someTokenValue"}`
req := testutils.MakeReq(server, "PATCH", "/verify-email", dat)
@ -272,9 +274,9 @@ func TestVerifyEmail(t *testing.T) {
var account database.Account
var token database.Token
var tokenCount int
testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
assert.Equal(t, account.EmailVerified, false, "email_verified mismatch")
assert.NotEqual(t, token.UsedAt, nil, "token used_at mismatch")
@ -283,11 +285,12 @@ func TestVerifyEmail(t *testing.T) {
})
t.Run("expired token", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -300,8 +303,8 @@ func TestVerifyEmail(t *testing.T) {
Type: database.TokenTypeEmailVerification,
Value: "someTokenValue",
}
testutils.MustExec(t, db.Save(&tok), "preparing token")
testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-31)), "Failed to prepare token created_at")
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-31)), "Failed to prepare token created_at")
dat := `{"token": "someTokenValue"}`
req := testutils.MakeReq(server, "PATCH", "/verify-email", dat)
@ -315,9 +318,9 @@ func TestVerifyEmail(t *testing.T) {
var account database.Account
var token database.Token
var tokenCount int
testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
assert.Equal(t, account.EmailVerified, false, "email_verified mismatch")
assert.Equal(t, tokenCount, 1, "token count mismatch")
@ -325,11 +328,12 @@ func TestVerifyEmail(t *testing.T) {
})
t.Run("already verified", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -337,14 +341,14 @@ func TestVerifyEmail(t *testing.T) {
user := testutils.SetupUserData()
a := testutils.SetupAccountData(user, "alice@example.com", "oldpass1234")
a.EmailVerified = true
testutils.MustExec(t, db.Save(&a), "preparing account")
testutils.MustExec(t, testutils.DB.Save(&a), "preparing account")
tok := database.Token{
UserID: user.ID,
Type: database.TokenTypeEmailVerification,
Value: "someTokenValue",
}
testutils.MustExec(t, db.Save(&tok), "preparing token")
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
dat := `{"token": "someTokenValue"}`
req := testutils.MakeReq(server, "PATCH", "/verify-email", dat)
@ -358,9 +362,9 @@ func TestVerifyEmail(t *testing.T) {
var account database.Account
var token database.Token
var tokenCount int
testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
assert.Equal(t, account.EmailVerified, true, "email_verified mismatch")
assert.Equal(t, tokenCount, 1, "token count mismatch")
@ -370,11 +374,12 @@ func TestVerifyEmail(t *testing.T) {
func TestUpdateEmail(t *testing.T) {
t.Run("success", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -382,7 +387,7 @@ func TestUpdateEmail(t *testing.T) {
u := testutils.SetupUserData()
a := testutils.SetupAccountData(u, "alice@example.com", "pass1234")
a.EmailVerified = true
testutils.MustExec(t, db.Save(&a), "updating email_verified")
testutils.MustExec(t, testutils.DB.Save(&a), "updating email_verified")
// Execute
dat := `{"email": "alice-new@example.com"}`
@ -394,8 +399,8 @@ func TestUpdateEmail(t *testing.T) {
var user database.User
var account database.Account
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user")
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
testutils.MustExec(t, testutils.DB.Where("id = ?", u.ID).First(&user), "finding user")
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account")
assert.Equal(t, account.Email.String, "alice-new@example.com", "email mismatch")
assert.Equal(t, account.EmailVerified, false, "EmailVerified mismatch")
@ -404,11 +409,12 @@ func TestUpdateEmail(t *testing.T) {
func TestUpdateEmailPreference(t *testing.T) {
t.Run("with login", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -425,16 +431,17 @@ func TestUpdateEmailPreference(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "")
var preference database.EmailPreference
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding account")
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding account")
assert.Equal(t, preference.DigestWeekly, true, "preference mismatch")
})
t.Run("with an unused token", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -446,7 +453,7 @@ func TestUpdateEmailPreference(t *testing.T) {
Type: database.TokenTypeEmailPreference,
Value: "someTokenValue",
}
testutils.MustExec(t, db.Save(&tok), "preparing token")
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
// Execute
dat := `{"digest_weekly": true}`
@ -460,9 +467,9 @@ func TestUpdateEmailPreference(t *testing.T) {
var preference database.EmailPreference
var preferenceCount int
var token database.Token
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding preference")
testutils.MustExec(t, db.Model(database.EmailPreference{}).Count(&preferenceCount), "counting preference")
testutils.MustExec(t, db.Where("id = ?", tok.ID).First(&token), "failed to find token")
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding preference")
testutils.MustExec(t, testutils.DB.Model(database.EmailPreference{}).Count(&preferenceCount), "counting preference")
testutils.MustExec(t, testutils.DB.Where("id = ?", tok.ID).First(&token), "failed to find token")
assert.Equal(t, preferenceCount, 1, "preference count mismatch")
assert.Equal(t, preference.DigestWeekly, true, "email mismatch")
@ -470,11 +477,12 @@ func TestUpdateEmailPreference(t *testing.T) {
})
t.Run("with nonexistent token", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -486,7 +494,7 @@ func TestUpdateEmailPreference(t *testing.T) {
Type: database.TokenTypeEmailPreference,
Value: "someTokenValue",
}
testutils.MustExec(t, db.Save(&tok), "preparing token")
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
dat := `{"digest_weekly": false}`
url := fmt.Sprintf("/account/email-preference?token=%s", "someNonexistentToken")
@ -499,16 +507,17 @@ func TestUpdateEmailPreference(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
var preference database.EmailPreference
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding preference")
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding preference")
assert.Equal(t, preference.DigestWeekly, true, "email mismatch")
})
t.Run("with expired token", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -523,7 +532,7 @@ func TestUpdateEmailPreference(t *testing.T) {
Value: "someTokenValue",
UsedAt: &usedAt,
}
testutils.MustExec(t, db.Save(&tok), "preparing token")
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
// Execute
dat := `{"digest_weekly": false}`
@ -535,16 +544,17 @@ func TestUpdateEmailPreference(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
var preference database.EmailPreference
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding preference")
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding preference")
assert.Equal(t, preference.DigestWeekly, true, "email mismatch")
})
t.Run("with a used but unexpired token", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -558,7 +568,7 @@ func TestUpdateEmailPreference(t *testing.T) {
Value: "someTokenValue",
UsedAt: &usedAt,
}
testutils.MustExec(t, db.Save(&tok), "preparing token")
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
dat := `{"digest_weekly": false}`
url := fmt.Sprintf("/account/email-preference?token=%s", "someTokenValue")
@ -571,16 +581,17 @@ func TestUpdateEmailPreference(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "")
var preference database.EmailPreference
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding preference")
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding preference")
assert.Equal(t, preference.DigestWeekly, false, "DigestWeekly mismatch")
})
t.Run("no user and no token", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -597,16 +608,17 @@ func TestUpdateEmailPreference(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
var preference database.EmailPreference
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding preference")
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding preference")
assert.Equal(t, preference.DigestWeekly, true, "email mismatch")
})
t.Run("create a record if not exists", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -617,7 +629,7 @@ func TestUpdateEmailPreference(t *testing.T) {
Type: database.TokenTypeEmailPreference,
Value: "someTokenValue",
}
testutils.MustExec(t, db.Save(&tok), "preparing token")
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
// Execute
dat := `{"digest_weekly": false}`
@ -629,20 +641,21 @@ func TestUpdateEmailPreference(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "")
var preferenceCount int
testutils.MustExec(t, db.Model(database.EmailPreference{}).Count(&preferenceCount), "counting preference")
testutils.MustExec(t, testutils.DB.Model(database.EmailPreference{}).Count(&preferenceCount), "counting preference")
assert.Equal(t, preferenceCount, 1, "preference count mismatch")
var preference database.EmailPreference
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding preference")
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding preference")
assert.Equal(t, preference.DigestWeekly, false, "email mismatch")
})
}
func TestGetEmailPreference(t *testing.T) {
defer testutils.ClearData()
defer testutils.ClearData()
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()

View file

@ -23,8 +23,9 @@ import (
"net/http"
"time"
"github.com/dnote/dnote/pkg/server/operations"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/operations"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
)
@ -63,9 +64,7 @@ func unsetSessionCookie(w http.ResponseWriter) {
http.SetCookie(w, &cookie)
}
func touchLastLoginAt(user database.User) error {
db := database.DBConn
func touchLastLoginAt(db *gorm.DB, user database.User) error {
t := time.Now()
if err := db.Model(&user).Update(database.User{LastLoginAt: &t}).Error; err != nil {
return errors.Wrap(err, "updating last_login_at")
@ -80,8 +79,6 @@ type signinPayload struct {
}
func (a *App) signin(w http.ResponseWriter, r *http.Request) {
db := database.DBConn
var params signinPayload
err := json.NewDecoder(r.Body).Decode(&params)
if err != nil {
@ -94,7 +91,7 @@ func (a *App) signin(w http.ResponseWriter, r *http.Request) {
}
var account database.Account
conn := db.Where("email = ?", params.Email).First(&account)
conn := a.DB.Where("email = ?", params.Email).First(&account)
if conn.RecordNotFound() {
http.Error(w, ErrLoginFailure.Error(), http.StatusUnauthorized)
return
@ -111,19 +108,19 @@ func (a *App) signin(w http.ResponseWriter, r *http.Request) {
}
var user database.User
err = db.Where("id = ?", account.UserID).First(&user).Error
err = a.DB.Where("id = ?", account.UserID).First(&user).Error
if err != nil {
HandleError(w, "finding user", err, http.StatusInternalServerError)
return
}
err = operations.TouchLastLoginAt(user, db)
err = operations.TouchLastLoginAt(user, a.DB)
if err != nil {
http.Error(w, errors.Wrap(err, "touching login timestamp").Error(), http.StatusInternalServerError)
return
}
respondWithSession(w, account.UserID, http.StatusOK)
respondWithSession(a.DB, w, account.UserID, http.StatusOK)
}
func (a *App) signoutOptions(w http.ResponseWriter, r *http.Request) {
@ -143,7 +140,7 @@ func (a *App) signout(w http.ResponseWriter, r *http.Request) {
return
}
err = operations.DeleteSession(database.DBConn, key)
err = operations.DeleteSession(a.DB, key)
if err != nil {
HandleError(w, "deleting session", nil, http.StatusInternalServerError)
return
@ -182,8 +179,6 @@ func parseRegisterPaylaod(r *http.Request) (registerPayload, bool) {
}
func (a *App) register(w http.ResponseWriter, r *http.Request) {
db := database.DBConn
params, ok := parseRegisterPaylaod(r)
if !ok {
http.Error(w, "invalid payload", http.StatusBadRequest)
@ -191,7 +186,7 @@ func (a *App) register(w http.ResponseWriter, r *http.Request) {
}
var count int
if err := db.Model(database.Account{}).Where("email = ?", params.Email).Count(&count).Error; err != nil {
if err := a.DB.Model(database.Account{}).Where("email = ?", params.Email).Count(&count).Error; err != nil {
HandleError(w, "checking duplicate user", err, http.StatusInternalServerError)
return
}
@ -200,20 +195,18 @@ func (a *App) register(w http.ResponseWriter, r *http.Request) {
return
}
user, err := operations.CreateUser(params.Email, params.Password)
user, err := operations.CreateUser(a.DB, params.Email, params.Password)
if err != nil {
HandleError(w, "creating user", err, http.StatusInternalServerError)
return
}
respondWithSession(w, user.ID, http.StatusCreated)
respondWithSession(a.DB, w, user.ID, http.StatusCreated)
}
// respondWithSession makes a HTTP response with the session from the user with the given userID.
// It sets the HTTP-Only cookie for browser clients and also sends a JSON response for non-browser clients.
func respondWithSession(w http.ResponseWriter, userID int, statusCode int) {
db := database.DBConn
func respondWithSession(db *gorm.DB, w http.ResponseWriter, userID int, statusCode int) {
session, err := operations.CreateSession(db, userID)
if err != nil {
HandleError(w, "creating session", nil, http.StatusBadRequest)

View file

@ -22,26 +22,17 @@ import (
"encoding/json"
"fmt"
"net/http"
"os"
"testing"
"time"
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/mailer"
"github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
)
func init() {
testutils.InitTestDB()
templatePath := os.Getenv("DNOTE_TEST_EMAIL_TEMPLATE_DIR")
mailer.InitTemplates(&templatePath)
}
func assertSessionResp(t *testing.T, res *http.Response) {
// after register, should sign in user
var got SessionResponse
@ -51,9 +42,8 @@ func assertSessionResp(t *testing.T, res *http.Response) {
var sessionCount int
var session database.Session
db := database.DBConn
testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
testutils.MustExec(t, db.First(&session), "getting session")
testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
testutils.MustExec(t, testutils.DB.First(&session), "getting session")
assert.Equal(t, sessionCount, 1, "sessionCount mismatch")
assert.Equal(t, got.Key, session.Key, "session Key mismatch")
@ -87,11 +77,12 @@ func TestRegister(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("register %s %s", tc.email, tc.password), func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -106,20 +97,20 @@ func TestRegister(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusCreated, "")
var account database.Account
testutils.MustExec(t, db.Where("email = ?", tc.email).First(&account), "finding account")
testutils.MustExec(t, testutils.DB.Where("email = ?", tc.email).First(&account), "finding account")
assert.Equal(t, account.Email.String, tc.email, "Email mismatch")
assert.NotEqual(t, account.UserID, 0, "UserID mismatch")
passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte(tc.password))
assert.Equal(t, passwordErr, nil, "Password mismatch")
var user database.User
testutils.MustExec(t, db.Where("id = ?", account.UserID).First(&user), "finding user")
testutils.MustExec(t, testutils.DB.Where("id = ?", account.UserID).First(&user), "finding user")
assert.Equal(t, user.Cloud, false, "Cloud mismatch")
assert.Equal(t, user.StripeCustomerID, "", "StripeCustomerID mismatch")
assert.Equal(t, user.MaxUSN, 0, "MaxUSN mismatch")
var repetitionRuleCount int
testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Where("user_id = ?", account.UserID).Count(&repetitionRuleCount), "counting repetition rules")
testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Where("user_id = ?", account.UserID).Count(&repetitionRuleCount), "counting repetition rules")
assert.Equal(t, repetitionRuleCount, 1, "repetitionRuleCount mismatch")
// after register, should sign in user
@ -130,11 +121,12 @@ func TestRegister(t *testing.T) {
func TestRegisterMissingParams(t *testing.T) {
t.Run("missing email", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -149,19 +141,20 @@ func TestRegisterMissingParams(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch")
var accountCount, userCount int
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account")
testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
assert.Equal(t, accountCount, 0, "accountCount mismatch")
assert.Equal(t, userCount, 0, "userCount mismatch")
})
t.Run("missing password", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -176,8 +169,8 @@ func TestRegisterMissingParams(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch")
var accountCount, userCount int
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account")
testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
assert.Equal(t, accountCount, 0, "accountCount mismatch")
assert.Equal(t, userCount, 0, "userCount mismatch")
@ -185,11 +178,12 @@ func TestRegisterMissingParams(t *testing.T) {
}
func TestRegisterDuplicateEmail(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -207,12 +201,12 @@ func TestRegisterDuplicateEmail(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "status code mismatch")
var accountCount, userCount, verificationTokenCount int
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
testutils.MustExec(t, db.Model(&database.Token{}).Count(&verificationTokenCount), "counting verification token")
testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account")
testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&verificationTokenCount), "counting verification token")
var user database.User
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user")
testutils.MustExec(t, testutils.DB.Where("id = ?", u.ID).First(&user), "finding user")
assert.Equal(t, accountCount, 1, "account count mismatch")
assert.Equal(t, userCount, 1, "user count mismatch")
@ -222,11 +216,12 @@ func TestRegisterDuplicateEmail(t *testing.T) {
func TestSignIn(t *testing.T) {
t.Run("success", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -244,7 +239,7 @@ func TestSignIn(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "")
var user database.User
testutils.MustExec(t, db.Model(&database.User{}).First(&user), "finding user")
testutils.MustExec(t, testutils.DB.Model(&database.User{}).First(&user), "finding user")
assert.NotEqual(t, user.LastLoginAt, nil, "LastLoginAt mismatch")
// after register, should sign in user
@ -252,11 +247,12 @@ func TestSignIn(t *testing.T) {
})
t.Run("wrong password", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -274,20 +270,21 @@ func TestSignIn(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
var user database.User
testutils.MustExec(t, db.Model(&database.User{}).First(&user), "finding user")
testutils.MustExec(t, testutils.DB.Model(&database.User{}).First(&user), "finding user")
assert.Equal(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch")
var sessionCount int
testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
assert.Equal(t, sessionCount, 0, "sessionCount mismatch")
})
t.Run("wrong email", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -305,20 +302,21 @@ func TestSignIn(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
var user database.User
testutils.MustExec(t, db.Model(&database.User{}).First(&user), "finding user")
testutils.MustExec(t, testutils.DB.Model(&database.User{}).First(&user), "finding user")
assert.DeepEqual(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch")
var sessionCount int
testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
assert.Equal(t, sessionCount, 0, "sessionCount mismatch")
})
t.Run("nonexistent email", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -333,14 +331,14 @@ func TestSignIn(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
var sessionCount int
testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
assert.Equal(t, sessionCount, 0, "sessionCount mismatch")
})
}
func TestSignout(t *testing.T) {
t.Run("authenticated", func(t *testing.T) {
db := database.DBConn
defer testutils.ClearData()
aliceUser := testutils.SetupUserData()
@ -352,16 +350,17 @@ func TestSignout(t *testing.T) {
UserID: aliceUser.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
testutils.MustExec(t, db.Save(&session1), "preparing session1")
testutils.MustExec(t, testutils.DB.Save(&session1), "preparing session1")
session2 := database.Session{
Key: "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=",
UserID: anotherUser.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
testutils.MustExec(t, db.Save(&session2), "preparing session2")
testutils.MustExec(t, testutils.DB.Save(&session2), "preparing session2")
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -376,8 +375,8 @@ func TestSignout(t *testing.T) {
var sessionCount int
var s2 database.Session
testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
testutils.MustExec(t, db.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&s2), "getting s2")
testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
testutils.MustExec(t, testutils.DB.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&s2), "getting s2")
assert.Equal(t, sessionCount, 1, "sessionCount mismatch")
@ -391,7 +390,7 @@ func TestSignout(t *testing.T) {
})
t.Run("unauthenticated", func(t *testing.T) {
db := database.DBConn
defer testutils.ClearData()
aliceUser := testutils.SetupUserData()
@ -403,16 +402,17 @@ func TestSignout(t *testing.T) {
UserID: aliceUser.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
testutils.MustExec(t, db.Save(&session1), "preparing session1")
testutils.MustExec(t, testutils.DB.Save(&session1), "preparing session1")
session2 := database.Session{
Key: "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=",
UserID: anotherUser.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
testutils.MustExec(t, db.Save(&session2), "preparing session2")
testutils.MustExec(t, testutils.DB.Save(&session2), "preparing session2")
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
})
defer server.Close()
@ -426,9 +426,9 @@ func TestSignout(t *testing.T) {
var sessionCount int
var postSession1, postSession2 database.Session
testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
testutils.MustExec(t, db.Where("key = ?", "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=").First(&postSession1), "getting postSession1")
testutils.MustExec(t, db.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&postSession2), "getting postSession2")
testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
testutils.MustExec(t, testutils.DB.Where("key = ?", "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=").First(&postSession1), "getting postSession1")
testutils.MustExec(t, testutils.DB.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&postSession2), "getting postSession2")
// two existing sessions should remain
assert.Equal(t, sessionCount, 2, "sessionCount mismatch")

View file

@ -24,11 +24,12 @@ import (
"net/http"
"net/url"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/operations"
"github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/database"
"github.com/gorilla/mux"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
)
@ -69,10 +70,8 @@ func (a *App) CreateBook(w http.ResponseWriter, r *http.Request) {
return
}
db := database.DBConn
var bookCount int
err = db.Model(database.Book{}).
err = a.DB.Model(database.Book{}).
Where("user_id = ? AND label = ?", user.ID, params.Name).
Count(&bookCount).Error
if err != nil {
@ -84,7 +83,7 @@ func (a *App) CreateBook(w http.ResponseWriter, r *http.Request) {
return
}
book, err := operations.CreateBook(user, a.Clock, params.Name)
book, err := operations.CreateBook(a.DB, user, a.Clock, params.Name)
if err != nil {
HandleError(w, "inserting book", err, http.StatusInternalServerError)
}
@ -100,9 +99,7 @@ func (a *App) BooksOptions(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Version")
}
func respondWithBooks(userID int, query url.Values, w http.ResponseWriter) {
db := database.DBConn
func respondWithBooks(db *gorm.DB, userID int, query url.Values, w http.ResponseWriter) {
var books []database.Book
conn := db.Where("user_id = ? AND NOT deleted", userID).Order("label ASC")
name := query.Get("name")
@ -132,19 +129,6 @@ func respondWithBooks(userID int, query url.Values, w http.ResponseWriter) {
respondJSON(w, http.StatusOK, presentedBooks)
}
// GetDemoBooks returns books for demo
func (a *App) GetDemoBooks(w http.ResponseWriter, r *http.Request) {
demoUserID, err := helpers.GetDemoUserID()
if err != nil {
HandleError(w, "finding demo user", err, http.StatusInternalServerError)
return
}
query := r.URL.Query()
respondWithBooks(demoUserID, query, w)
}
// GetBooks returns books for the user
func (a *App) GetBooks(w http.ResponseWriter, r *http.Request) {
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
@ -154,7 +138,7 @@ func (a *App) GetBooks(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
respondWithBooks(user.ID, query, w)
respondWithBooks(a.DB, user.ID, query, w)
}
// GetBook returns a book for the user
@ -164,13 +148,11 @@ func (a *App) GetBook(w http.ResponseWriter, r *http.Request) {
return
}
db := database.DBConn
vars := mux.Vars(r)
bookUUID := vars["bookUUID"]
var book database.Book
conn := db.Where("uuid = ? AND user_id = ?", bookUUID, user.ID).First(&book)
conn := a.DB.Where("uuid = ? AND user_id = ?", bookUUID, user.ID).First(&book)
if conn.RecordNotFound() {
w.WriteHeader(http.StatusNotFound)
@ -204,8 +186,7 @@ func (a *App) UpdateBook(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
uuid := vars["bookUUID"]
db := database.DBConn
tx := db.Begin()
tx := a.DB.Begin()
var book database.Book
if err := tx.Where("user_id = ? AND uuid = ?", user.ID, uuid).First(&book).Error; err != nil {
@ -250,8 +231,7 @@ func (a *App) DeleteBook(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
uuid := vars["bookUUID"]
db := database.DBConn
tx := db.Begin()
tx := a.DB.Begin()
var book database.Book
if err := tx.Where("user_id = ? AND uuid = ?", user.ID, uuid).First(&book).Error; err != nil {

View file

@ -26,23 +26,20 @@ import (
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
)
func init() {
testutils.InitTestDB()
}
func TestGetBooks(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
Clock: clock.NewMock(),
})
defer server.Close()
@ -55,28 +52,28 @@ func TestGetBooks(t *testing.T) {
USN: 1123,
Deleted: false,
}
testutils.MustExec(t, db.Save(&b1), "preparing b1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
b2 := database.Book{
UserID: user.ID,
Label: "css",
USN: 1125,
Deleted: false,
}
testutils.MustExec(t, db.Save(&b2), "preparing b2")
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
b3 := database.Book{
UserID: anotherUser.ID,
Label: "css",
USN: 1128,
Deleted: false,
}
testutils.MustExec(t, db.Save(&b3), "preparing b3")
testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
b4 := database.Book{
UserID: user.ID,
Label: "",
USN: 1129,
Deleted: true,
}
testutils.MustExec(t, db.Save(&b4), "preparing b4")
testutils.MustExec(t, testutils.DB.Save(&b4), "preparing b4")
// Execute
req := testutils.MakeReq(server, "GET", "/v3/books", "")
@ -91,9 +88,9 @@ func TestGetBooks(t *testing.T) {
}
var b1Record, b2Record database.Book
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
testutils.MustExec(t, db.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
testutils.MustExec(t, db.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
testutils.MustExec(t, testutils.DB.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
testutils.MustExec(t, testutils.DB.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
expected := []presenters.Book{
{
@ -116,12 +113,13 @@ func TestGetBooks(t *testing.T) {
}
func TestGetBooksByName(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
Clock: clock.NewMock(),
})
defer server.Close()
@ -133,17 +131,17 @@ func TestGetBooksByName(t *testing.T) {
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, db.Save(&b1), "preparing b1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
b2 := database.Book{
UserID: user.ID,
Label: "css",
}
testutils.MustExec(t, db.Save(&b2), "preparing b2")
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
b3 := database.Book{
UserID: anotherUser.ID,
Label: "js",
}
testutils.MustExec(t, db.Save(&b3), "preparing b3")
testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
// Execute
res := testutils.HTTPAuthDo(t, req, user)
@ -157,7 +155,7 @@ func TestGetBooksByName(t *testing.T) {
}
var b1Record database.Book
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
expected := []presenters.Book{
{
@ -201,39 +199,40 @@ func TestDeleteBook(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("originally deleted %t", tc.deleted), func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
Clock: clock.NewMock(),
})
defer server.Close()
user := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&user).Update("max_usn", 58), "preparing user max_usn")
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 58), "preparing user max_usn")
anotherUser := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 109), "preparing another user max_usn")
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 109), "preparing another user max_usn")
b1 := database.Book{
UserID: user.ID,
Label: "js",
USN: 1,
}
testutils.MustExec(t, db.Save(&b1), "preparing a book data")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing a book data")
b2 := database.Book{
UserID: user.ID,
Label: tc.label,
USN: 2,
Deleted: tc.deleted,
}
testutils.MustExec(t, db.Save(&b2), "preparing a book data")
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing a book data")
b3 := database.Book{
UserID: anotherUser.ID,
Label: "linux",
USN: 3,
}
testutils.MustExec(t, db.Save(&b3), "preparing a book data")
testutils.MustExec(t, testutils.DB.Save(&b3), "preparing a book data")
var n2Body string
if !tc.deleted {
@ -250,7 +249,7 @@ func TestDeleteBook(t *testing.T) {
Body: "n1 content",
USN: 4,
}
testutils.MustExec(t, db.Save(&n1), "preparing a note data")
testutils.MustExec(t, testutils.DB.Save(&n1), "preparing a note data")
n2 := database.Note{
UserID: user.ID,
BookUUID: b2.UUID,
@ -258,7 +257,7 @@ func TestDeleteBook(t *testing.T) {
USN: 5,
Deleted: tc.deleted,
}
testutils.MustExec(t, db.Save(&n2), "preparing a note data")
testutils.MustExec(t, testutils.DB.Save(&n2), "preparing a note data")
n3 := database.Note{
UserID: user.ID,
BookUUID: b2.UUID,
@ -266,7 +265,7 @@ func TestDeleteBook(t *testing.T) {
USN: 6,
Deleted: tc.deleted,
}
testutils.MustExec(t, db.Save(&n3), "preparing a note data")
testutils.MustExec(t, testutils.DB.Save(&n3), "preparing a note data")
n4 := database.Note{
UserID: user.ID,
BookUUID: b2.UUID,
@ -274,14 +273,14 @@ func TestDeleteBook(t *testing.T) {
USN: 7,
Deleted: true,
}
testutils.MustExec(t, db.Save(&n4), "preparing a note data")
testutils.MustExec(t, testutils.DB.Save(&n4), "preparing a note data")
n5 := database.Note{
UserID: anotherUser.ID,
BookUUID: b3.UUID,
Body: "n5 content",
USN: 8,
}
testutils.MustExec(t, db.Save(&n5), "preparing a note data")
testutils.MustExec(t, testutils.DB.Save(&n5), "preparing a note data")
endpoint := fmt.Sprintf("/v3/books/%s", b2.UUID)
req := testutils.MakeReq(server, "DELETE", endpoint, "")
@ -299,17 +298,17 @@ func TestDeleteBook(t *testing.T) {
var userRecord database.User
var bookCount, noteCount int
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
testutils.MustExec(t, db.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
testutils.MustExec(t, db.Where("id = ?", b3.ID).First(&b3Record), "finding b3")
testutils.MustExec(t, db.Where("id = ?", n1.ID).First(&n1Record), "finding n1")
testutils.MustExec(t, db.Where("id = ?", n2.ID).First(&n2Record), "finding n2")
testutils.MustExec(t, db.Where("id = ?", n3.ID).First(&n3Record), "finding n3")
testutils.MustExec(t, db.Where("id = ?", n4.ID).First(&n4Record), "finding n4")
testutils.MustExec(t, db.Where("id = ?", n5.ID).First(&n5Record), "finding n5")
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
testutils.MustExec(t, testutils.DB.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
testutils.MustExec(t, testutils.DB.Where("id = ?", b3.ID).First(&b3Record), "finding b3")
testutils.MustExec(t, testutils.DB.Where("id = ?", n1.ID).First(&n1Record), "finding n1")
testutils.MustExec(t, testutils.DB.Where("id = ?", n2.ID).First(&n2Record), "finding n2")
testutils.MustExec(t, testutils.DB.Where("id = ?", n3.ID).First(&n3Record), "finding n3")
testutils.MustExec(t, testutils.DB.Where("id = ?", n4.ID).First(&n4Record), "finding n4")
testutils.MustExec(t, testutils.DB.Where("id = ?", n5.ID).First(&n5Record), "finding n5")
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equal(t, bookCount, 3, "book count mismatch")
assert.Equal(t, noteCount, 5, "note count mismatch")
@ -351,17 +350,18 @@ func TestDeleteBook(t *testing.T) {
}
func TestCreateBook(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
Clock: clock.NewMock(),
})
defer server.Close()
user := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
req := testutils.MakeReq(server, "POST", "/v3/books", `{"name": "js"}`)
req.Header.Set("Version", "0.1.1")
@ -376,10 +376,10 @@ func TestCreateBook(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
var bookCount, noteCount int
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, db.First(&bookRecord), "finding book")
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, testutils.DB.First(&bookRecord), "finding book")
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
maxUSN := 102
@ -410,24 +410,25 @@ func TestCreateBook(t *testing.T) {
}
func TestCreateBookDuplicate(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
Clock: clock.NewMock(),
})
defer server.Close()
user := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
UserID: user.ID,
Label: "js",
USN: 58,
}
testutils.MustExec(t, db.Save(&b1), "preparing book data")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing book data")
// Execute
req := testutils.MakeReq(server, "POST", "/v3/books", `{"name": "js"}`)
@ -439,10 +440,10 @@ func TestCreateBookDuplicate(t *testing.T) {
var bookRecord database.Book
var bookCount, noteCount int
var userRecord database.User
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, db.First(&bookRecord), "finding book")
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, testutils.DB.First(&bookRecord), "finding book")
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, 1, "book count mismatch")
assert.Equalf(t, noteCount, 0, "note count mismatch")
@ -489,17 +490,18 @@ func TestUpdateBook(t *testing.T) {
for idx, tc := range testCases {
func() {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
Clock: clock.NewMock(),
})
defer server.Close()
user := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
UUID: tc.bookUUID,
@ -507,15 +509,15 @@ func TestUpdateBook(t *testing.T) {
Label: tc.bookLabel,
Deleted: tc.bookDeleted,
}
testutils.MustExec(t, db.Save(&b1), "preparing b1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
b2 := database.Book{
UUID: b2UUID,
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, db.Save(&b2), "preparing b2")
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
// Execute
// Executdb,e
endpoint := fmt.Sprintf("/v3/books/%s", tc.bookUUID)
req := testutils.MakeReq(server, "PATCH", endpoint, tc.payload)
res := testutils.HTTPAuthDo(t, req, user)
@ -526,10 +528,10 @@ func TestUpdateBook(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
var noteCount, bookCount int
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, 2, "book count mismatch")
assert.Equalf(t, noteCount, 0, "note count mismatch")

View file

@ -23,10 +23,10 @@ import (
"fmt"
"net/http"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/operations"
"github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/database"
"github.com/gorilla/mux"
"github.com/pkg/errors"
)
@ -48,7 +48,6 @@ func validateUpdateNotePayload(p updateNotePayload) bool {
// UpdateNote updates note
func (a *App) UpdateNote(w http.ResponseWriter, r *http.Request) {
db := database.DBConn
vars := mux.Vars(r)
noteUUID := vars["noteUUID"]
@ -71,12 +70,12 @@ func (a *App) UpdateNote(w http.ResponseWriter, r *http.Request) {
}
var note database.Note
if err := db.Where("uuid = ? AND user_id = ?", noteUUID, user.ID).First(&note).Error; err != nil {
if err := a.DB.Where("uuid = ? AND user_id = ?", noteUUID, user.ID).First(&note).Error; err != nil {
HandleError(w, "finding note", err, http.StatusInternalServerError)
return
}
tx := db.Begin()
tx := a.DB.Begin()
note, err = operations.UpdateNote(tx, user, a.Clock, note, &operations.UpdateNoteParams{
BookUUID: params.BookUUID,
@ -116,8 +115,6 @@ type deleteNoteResp struct {
// DeleteNote removes note
func (a *App) DeleteNote(w http.ResponseWriter, r *http.Request) {
db := database.DBConn
vars := mux.Vars(r)
noteUUID := vars["noteUUID"]
@ -128,12 +125,12 @@ func (a *App) DeleteNote(w http.ResponseWriter, r *http.Request) {
}
var note database.Note
if err := db.Where("uuid = ? AND user_id = ?", noteUUID, user.ID).Preload("Book").First(&note).Error; err != nil {
if err := a.DB.Where("uuid = ? AND user_id = ?", noteUUID, user.ID).Preload("Book").First(&note).Error; err != nil {
HandleError(w, "finding note", err, http.StatusInternalServerError)
return
}
tx := db.Begin()
tx := a.DB.Begin()
n, err := operations.DeleteNote(tx, user, note)
if err != nil {
@ -193,13 +190,12 @@ func (a *App) CreateNote(w http.ResponseWriter, r *http.Request) {
}
var book database.Book
db := database.DBConn
if err := db.Where("uuid = ? AND user_id = ?", params.BookUUID, user.ID).First(&book).Error; err != nil {
if err := a.DB.Where("uuid = ? AND user_id = ?", params.BookUUID, user.ID).First(&book).Error; err != nil {
HandleError(w, "finding book", err, http.StatusInternalServerError)
return
}
note, err := operations.CreateNote(user, a.Clock, params.BookUUID, params.Content, params.AddedOn, params.EditedOn, false)
note, err := operations.CreateNote(a.DB, user, a.Clock, params.BookUUID, params.Content, params.AddedOn, params.EditedOn, false)
if err != nil {
HandleError(w, "creating note", err, http.StatusInternalServerError)
return

View file

@ -29,29 +29,26 @@ import (
"github.com/dnote/dnote/pkg/server/testutils"
)
func init() {
testutils.InitTestDB()
}
func TestCreateNote(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
Clock: clock.NewMock(),
})
defer server.Close()
user := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
UserID: user.ID,
Label: "js",
USN: 58,
}
testutils.MustExec(t, db.Save(&b1), "preparing b1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
// Execute
dat := fmt.Sprintf(`{"book_uuid": "%s", "content": "note content"}`, b1.UUID)
@ -65,11 +62,11 @@ func TestCreateNote(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
var bookCount, noteCount int
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, db.First(&noteRecord), "finding note")
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, testutils.DB.First(&noteRecord), "finding note")
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, 1, "book count mismatch")
assert.Equalf(t, noteCount, 1, "note count mismatch")
@ -238,30 +235,31 @@ func TestUpdateNote(t *testing.T) {
for idx, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
Clock: clock.NewMock(),
})
defer server.Close()
user := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
UUID: b1UUID,
UserID: user.ID,
Label: "css",
}
testutils.MustExec(t, db.Save(&b1), "preparing b1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
b2 := database.Book{
UUID: b2UUID,
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, db.Save(&b2), "preparing b2")
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
note := database.Note{
UserID: user.ID,
@ -271,7 +269,7 @@ func TestUpdateNote(t *testing.T) {
Deleted: tc.noteDeleted,
Public: tc.notePublic,
}
testutils.MustExec(t, db.Save(&note), "preparing note")
testutils.MustExec(t, testutils.DB.Save(&note), "preparing note")
// Execute
endpoint := fmt.Sprintf("/v3/notes/%s", note.UUID)
@ -285,11 +283,11 @@ func TestUpdateNote(t *testing.T) {
var noteRecord database.Note
var userRecord database.User
var noteCount, bookCount int
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, db.Where("uuid = ?", note.UUID).First(&noteRecord), "finding note")
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", note.UUID).First(&noteRecord), "finding note")
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, 2, "book count mismatch")
assert.Equalf(t, noteCount, 1, "note count mismatch")
@ -333,24 +331,25 @@ func TestDeleteNote(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("originally deleted %t", tc.deleted), func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
// Setup
server := MustNewServer(t, &App{
Clock: clock.NewMock(),
Clock: clock.NewMock(),
})
defer server.Close()
user := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&user).Update("max_usn", 981), "preparing user max_usn")
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 981), "preparing user max_usn")
b1 := database.Book{
UUID: b1UUID,
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, db.Save(&b1), "preparing b1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
note := database.Note{
UserID: user.ID,
BookUUID: b1.UUID,
@ -358,7 +357,7 @@ func TestDeleteNote(t *testing.T) {
Deleted: tc.deleted,
USN: tc.originalUSN,
}
testutils.MustExec(t, db.Save(&note), "preparing note")
testutils.MustExec(t, testutils.DB.Save(&note), "preparing note")
// Execute
endpoint := fmt.Sprintf("/v3/notes/%s", note.UUID)
@ -372,11 +371,11 @@ func TestDeleteNote(t *testing.T) {
var noteRecord database.Note
var userRecord database.User
var bookCount, noteCount int
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, db.Where("uuid = ?", note.UUID).First(&noteRecord), "finding note")
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", note.UUID).First(&noteRecord), "finding note")
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, 1, "book count mismatch")
assert.Equalf(t, noteCount, 1, "note count mismatch")

View file

@ -26,8 +26,8 @@ import (
"strconv"
"time"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/log"
"github.com/pkg/errors"
)
@ -121,14 +121,12 @@ func (e *queryParamError) Error() string {
}
func (a *App) newFragment(userID, userMaxUSN, afterUSN, limit int) (SyncFragment, error) {
db := database.DBConn
var notes []database.Note
if err := db.Where("user_id = ? AND usn > ? AND usn <= ?", userID, afterUSN, userMaxUSN).Order("usn ASC").Limit(limit).Find(&notes).Error; err != nil {
if err := a.DB.Where("user_id = ? AND usn > ? AND usn <= ?", userID, afterUSN, userMaxUSN).Order("usn ASC").Limit(limit).Find(&notes).Error; err != nil {
return SyncFragment{}, nil
}
var books []database.Book
if err := db.Where("user_id = ? AND usn > ? AND usn <= ?", userID, afterUSN, userMaxUSN).Order("usn ASC").Limit(limit).Find(&books).Error; err != nil {
if err := a.DB.Where("user_id = ? AND usn > ? AND usn <= ?", userID, afterUSN, userMaxUSN).Order("usn ASC").Limit(limit).Find(&books).Error; err != nil {
return SyncFragment{}, nil
}

View file

@ -19,8 +19,8 @@
package helpers
import (
"github.com/dnote/dnote/pkg/server/database"
"github.com/google/uuid"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
)
@ -29,8 +29,7 @@ const (
)
// GetDemoUserID returns ID of the demo user
func GetDemoUserID() (int, error) {
db := database.DBConn
func GetDemoUserID(db *gorm.DB) (int, error) {
result := struct {
UserID int

View file

@ -24,6 +24,7 @@ import (
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/job/repetition"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
"github.com/robfig/cron"
)
@ -45,12 +46,12 @@ func checkEnvironment() error {
return nil
}
func schedule(ch chan error) {
func schedule(db *gorm.DB, ch chan error) {
cl := clock.New()
// Schedule jobs
c := cron.New()
scheduleJob(c, "* * * * *", func() { repetition.Do(cl) })
scheduleJob(c, "* * * * *", func() { repetition.Do(db, cl) })
c.Start()
ch <- nil
@ -60,13 +61,13 @@ func schedule(ch chan error) {
}
// Run starts the background tasks in a separate goroutine that runs forever
func Run() error {
func Run(db *gorm.DB) error {
if err := checkEnvironment(); err != nil {
return errors.Wrap(err, "checking environment variables")
}
ch := make(chan error)
go schedule(ch)
go schedule(db, ch)
if err := <-ch; err != nil {
return errors.Wrap(err, "scheduling jobs")
}

View file

@ -0,0 +1,17 @@
package repetition
import (
"os"
"testing"
"github.com/dnote/dnote/pkg/server/testutils"
)
func TestMain(m *testing.M) {
testutils.InitTestDB()
code := m.Run()
testutils.ClearData()
os.Exit(code)
}

View file

@ -31,20 +31,29 @@ import (
"github.com/pkg/errors"
)
// BuildEmailParams is the params for building an email
type BuildEmailParams struct {
Now time.Time
User database.User
EmailAddr string
Digest database.Digest
Rule database.RepetitionRule
}
// BuildEmail builds an email for the spaced repetition
func BuildEmail(now time.Time, user database.User, emailAddr string, digest database.Digest, rule database.RepetitionRule) (*mailer.Email, error) {
date := now.Format("Jan 02 2006")
subject := fmt.Sprintf("%s %s", rule.Title, date)
tok, err := mailer.GetToken(user, database.TokenTypeRepetition)
func BuildEmail(db *gorm.DB, p BuildEmailParams) (*mailer.Email, error) {
date := p.Now.Format("Jan 02 2006")
subject := fmt.Sprintf("%s %s", p.Rule.Title, date)
tok, err := mailer.GetToken(db, p.User, database.TokenTypeRepetition)
if err != nil {
return nil, errors.Wrap(err, "getting email frequency token")
}
t1 := now.AddDate(0, 0, -3).UnixNano()
t2 := now.AddDate(0, 0, -7).UnixNano()
t1 := p.Now.AddDate(0, 0, -3).UnixNano()
t2 := p.Now.AddDate(0, 0, -7).UnixNano()
noteInfos := []mailer.DigestNoteInfo{}
for _, note := range digest.Notes {
for _, note := range p.Digest.Notes {
var stage int
if note.AddedOn > t1 {
stage = 1
@ -60,7 +69,7 @@ func BuildEmail(now time.Time, user database.User, emailAddr string, digest data
bookCount := 0
bookMap := map[string]bool{}
for _, n := range digest.Notes {
for _, n := range p.Digest.Notes {
if ok := bookMap[n.Book.Label]; !ok {
bookCount++
bookMap[n.Book.Label] = true
@ -71,14 +80,14 @@ func BuildEmail(now time.Time, user database.User, emailAddr string, digest data
Subject: subject,
NoteInfo: noteInfos,
ActiveBookCount: bookCount,
ActiveNoteCount: len(digest.Notes),
ActiveNoteCount: len(p.Digest.Notes),
EmailSessionToken: tok.Value,
RuleUUID: rule.UUID,
RuleTitle: rule.Title,
RuleUUID: p.Rule.UUID,
RuleTitle: p.Rule.Title,
WebURL: os.Getenv("WebURL"),
}
email := mailer.NewEmail("noreply@getdnote.com", []string{emailAddr}, subject)
email := mailer.NewEmail("noreply@getdnote.com", []string{p.EmailAddr}, subject)
if err := email.ParseTemplate(mailer.EmailTypeWeeklyDigest, tmplData); err != nil {
return nil, err
}
@ -86,12 +95,11 @@ func BuildEmail(now time.Time, user database.User, emailAddr string, digest data
return email, nil
}
func getEligibleRules(now time.Time) ([]database.RepetitionRule, error) {
func getEligibleRules(db *gorm.DB, now time.Time) ([]database.RepetitionRule, error) {
hour := now.Hour()
minute := now.Minute()
var ret []database.RepetitionRule
db := database.DBConn
if err := db.
Where("users.cloud AND repetition_rules.hour = ? AND repetition_rules.minute = ? AND repetition_rules.enabled", hour, minute).
Joins("INNER JOIN users ON users.id = repetition_rules.user_id").
@ -120,9 +128,7 @@ func build(tx *gorm.DB, rule database.RepetitionRule) (database.Digest, error) {
return digest, nil
}
func notify(now time.Time, user database.User, digest database.Digest, rule database.RepetitionRule) error {
db := database.DBConn
func notify(db *gorm.DB, now time.Time, user database.User, digest database.Digest, rule database.RepetitionRule) error {
var account database.Account
if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
return errors.Wrap(err, "getting account")
@ -135,7 +141,13 @@ func notify(now time.Time, user database.User, digest database.Digest, rule data
return nil
}
email, err := BuildEmail(now, user, account.Email.String, digest, rule)
email, err := BuildEmail(db, BuildEmailParams{
Now: now,
User: user,
EmailAddr: account.Email.String,
Digest: digest,
Rule: rule,
})
if err != nil {
return errors.Wrap(err, "making email")
}
@ -185,12 +197,11 @@ func touchTimestamp(tx *gorm.DB, rule database.RepetitionRule, now time.Time) er
return nil
}
func process(now time.Time, rule database.RepetitionRule) error {
func process(db *gorm.DB, now time.Time, rule database.RepetitionRule) error {
log.WithFields(log.Fields{
"uuid": rule.UUID,
}).Info("processing repetition")
db := database.DBConn
tx := db.Begin()
if !checkCooldown(now, rule) {
@ -224,7 +235,7 @@ func process(now time.Time, rule database.RepetitionRule) error {
return errors.Wrap(err, "committing transaction")
}
if err := notify(now, user, digest, rule); err != nil {
if err := notify(db, now, user, digest, rule); err != nil {
return errors.Wrap(err, "notifying user")
}
@ -236,10 +247,10 @@ func process(now time.Time, rule database.RepetitionRule) error {
}
// Do creates spaced repetitions and delivers the results based on the rules
func Do(c clock.Clock) error {
func Do(db *gorm.DB, c clock.Clock) error {
now := c.Now().UTC()
rules, err := getEligibleRules(now)
rules, err := getEligibleRules(db, now)
if err != nil {
return errors.Wrap(err, "getting eligible repetition rules")
}
@ -251,7 +262,7 @@ func Do(c clock.Clock) error {
}).Info("processing rules")
for _, rule := range rules {
if err := process(now, rule); err != nil {
if err := process(db, now, rule); err != nil {
log.WithFields(log.Fields{
"rule uuid": rule.UUID,
}).ErrorWrap(err, "Could not process the repetition rule")

View file

@ -29,24 +29,18 @@ import (
"github.com/dnote/dnote/pkg/server/testutils"
)
func init() {
testutils.InitTestDB()
}
func assertLastActive(t *testing.T, ruleUUID string, lastActive int64) {
db := database.DBConn
var rule database.RepetitionRule
testutils.MustExec(t, db.Where("uuid = ?", ruleUUID).First(&rule), "finding rule1")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", ruleUUID).First(&rule), "finding rule1")
assert.Equal(t, rule.LastActive, lastActive, "LastActive mismatch")
}
func assertDigestCount(t *testing.T, rule database.RepetitionRule, expected int) {
db := database.DBConn
var digestCount int
testutils.MustExec(t, db.Model(&database.Digest{}).Where("rule_id = ? AND user_id = ?", rule.ID, rule.UserID).Count(&digestCount), "counting digest")
testutils.MustExec(t, testutils.DB.Model(&database.Digest{}).Where("rule_id = ? AND user_id = ?", rule.ID, rule.UserID).Count(&digestCount), "counting digest")
assert.Equal(t, digestCount, expected, "digest count mismatch")
}
@ -74,68 +68,67 @@ func TestDo(t *testing.T) {
},
}
db := database.DBConn
testutils.MustExec(t, db.Save(&r1), "preparing rule1")
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
c := clock.NewMock()
// Test
// 1 day later
c.SetNow(time.Date(2009, time.November, 2, 12, 2, 1, 0, time.UTC))
Do(c)
Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(0))
assertDigestCount(t, r1, 0)
// 2 days later
c.SetNow(time.Date(2009, time.November, 3, 12, 2, 1, 0, time.UTC))
Do(c)
Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(0))
assertDigestCount(t, r1, 0)
// 3 days later - should be processed
c.SetNow(time.Date(2009, time.November, 4, 12, 1, 1, 0, time.UTC))
Do(c)
Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(0))
assertDigestCount(t, r1, 0)
c.SetNow(time.Date(2009, time.November, 4, 12, 2, 1, 0, time.UTC))
Do(c)
Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(1257336120000))
assertDigestCount(t, r1, 1)
c.SetNow(time.Date(2009, time.November, 4, 12, 3, 1, 0, time.UTC))
Do(c)
Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(1257336120000))
assertDigestCount(t, r1, 1)
// 4 day later
c.SetNow(time.Date(2009, time.November, 5, 12, 2, 1, 0, time.UTC))
Do(c)
Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(1257336120000))
assertDigestCount(t, r1, 1)
// 5 days later
c.SetNow(time.Date(2009, time.November, 6, 12, 2, 1, 0, time.UTC))
Do(c)
Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(1257336120000))
assertDigestCount(t, r1, 1)
// 6 days later - should be processed
c.SetNow(time.Date(2009, time.November, 7, 12, 2, 1, 0, time.UTC))
Do(c)
Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(1257595320000))
assertDigestCount(t, r1, 2)
// 7 days later
c.SetNow(time.Date(2009, time.November, 8, 12, 2, 1, 0, time.UTC))
Do(c)
Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(1257595320000))
assertDigestCount(t, r1, 2)
// 8 days later
c.SetNow(time.Date(2009, time.November, 9, 12, 2, 1, 0, time.UTC))
Do(c)
Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(1257595320000))
assertDigestCount(t, r1, 2)
// 9 days later - should be processed
c.SetNow(time.Date(2009, time.November, 10, 12, 2, 1, 0, time.UTC))
Do(c)
Do(testutils.DB, c)
assertLastActive(t, r1.UUID, int64(1257854520000))
assertDigestCount(t, r1, 3)
})
@ -177,15 +170,14 @@ func TestDo(t *testing.T) {
},
}
db := database.DBConn
testutils.MustExec(t, db.Save(&r1), "preparing rule1")
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
c := clock.NewMock()
c.SetNow(time.Date(2009, time.November, 10, 12, 2, 1, 0, time.UTC))
Do(c)
Do(testutils.DB, c)
var rule database.RepetitionRule
testutils.MustExec(t, db.Where("uuid = ?", r1.UUID).First(&rule), "finding rule1")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", r1.UUID).First(&rule), "finding rule1")
assert.Equal(t, rule.LastActive, time.Date(2009, time.November, 10, 12, 2, 0, 0, time.UTC).UnixNano()/int64(time.Millisecond), "LastActive mismsatch")
assert.Equal(t, rule.NextActive, time.Date(2009, time.November, 13, 12, 2, 0, 0, time.UTC).UnixNano()/int64(time.Millisecond), "NextActive mismsatch")
@ -216,13 +208,12 @@ func TestDo_Disabled(t *testing.T) {
},
}
db := database.DBConn
testutils.MustExec(t, db.Save(&r1), "preparing rule1")
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
// Execute
c := clock.NewMock()
c.SetNow(time.Date(2009, time.November, 4, 12, 2, 0, 0, time.UTC))
Do(c)
Do(testutils.DB, c)
// Test
assertLastActive(t, r1.UUID, int64(0))
@ -241,40 +232,39 @@ func TestDo_BalancedStrategy(t *testing.T) {
}
setup := func() testData {
db := database.DBConn
user := testutils.SetupUserData()
b1 := database.Book{
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, db.Save(&b1), "preparing b1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
b2 := database.Book{
UserID: user.ID,
Label: "css",
}
testutils.MustExec(t, db.Save(&b2), "preparing b2")
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
b3 := database.Book{
UserID: user.ID,
Label: "golang",
}
testutils.MustExec(t, db.Save(&b3), "preparing b3")
testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
n1 := database.Note{
UserID: user.ID,
BookUUID: b1.UUID,
}
testutils.MustExec(t, db.Save(&n1), "preparing n1")
testutils.MustExec(t, testutils.DB.Save(&n1), "preparing n1")
n2 := database.Note{
UserID: user.ID,
BookUUID: b2.UUID,
}
testutils.MustExec(t, db.Save(&n2), "preparing n2")
testutils.MustExec(t, testutils.DB.Save(&n2), "preparing n2")
n3 := database.Note{
UserID: user.ID,
BookUUID: b3.UUID,
}
testutils.MustExec(t, db.Save(&n3), "preparing n3")
testutils.MustExec(t, testutils.DB.Save(&n3), "preparing n3")
return testData{
User: user,
@ -293,7 +283,6 @@ func TestDo_BalancedStrategy(t *testing.T) {
// Set up
dat := setup()
db := database.DBConn
t0 := time.Date(2009, time.November, 1, 12, 0, 0, 0, time.UTC)
t1 := time.Date(2009, time.November, 8, 21, 0, 0, 0, time.UTC)
r1 := database.RepetitionRule{
@ -312,20 +301,20 @@ func TestDo_BalancedStrategy(t *testing.T) {
UpdatedAt: t0,
},
}
testutils.MustExec(t, db.Save(&r1), "preparing rule1")
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
// Execute
c := clock.NewMock()
c.SetNow(time.Date(2009, time.November, 8, 21, 0, 0, 0, time.UTC))
Do(c)
Do(testutils.DB, c)
// Test
assertLastActive(t, r1.UUID, int64(1257714000000))
assertDigestCount(t, r1, 1)
var repetition database.Digest
testutils.MustExec(t, db.Where("rule_id = ? AND user_id = ?", r1.ID, r1.UserID).Preload("Notes").First(&repetition), "finding repetition")
testutils.MustExec(t, testutils.DB.Where("rule_id = ? AND user_id = ?", r1.ID, r1.UserID).Preload("Notes").First(&repetition), "finding repetition")
sort.SliceStable(repetition.Notes, func(i, j int) bool {
n1 := repetition.Notes[i]
@ -335,9 +324,9 @@ func TestDo_BalancedStrategy(t *testing.T) {
})
var n1Record, n2Record, n3Record database.Note
testutils.MustExec(t, db.Where("uuid = ?", dat.Note1.UUID).First(&n1Record), "finding n1")
testutils.MustExec(t, db.Where("uuid = ?", dat.Note2.UUID).First(&n2Record), "finding n2")
testutils.MustExec(t, db.Where("uuid = ?", dat.Note3.UUID).First(&n3Record), "finding n3")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note1.UUID).First(&n1Record), "finding n1")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note2.UUID).First(&n2Record), "finding n2")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note3.UUID).First(&n3Record), "finding n3")
expected := []database.Note{n1Record, n2Record, n3Record}
assert.DeepEqual(t, repetition.Notes, expected, "result mismatch")
})
@ -348,7 +337,6 @@ func TestDo_BalancedStrategy(t *testing.T) {
// Set up
dat := setup()
db := database.DBConn
t0 := time.Date(2009, time.November, 1, 12, 0, 0, 0, time.UTC)
t1 := time.Date(2009, time.November, 8, 21, 0, 0, 0, time.UTC)
r1 := database.RepetitionRule{
@ -368,20 +356,20 @@ func TestDo_BalancedStrategy(t *testing.T) {
UpdatedAt: t0,
},
}
testutils.MustExec(t, db.Save(&r1), "preparing rule1")
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
// Execute
c := clock.NewMock()
c.SetNow(time.Date(2009, time.November, 8, 21, 0, 1, 0, time.UTC))
Do(c)
Do(testutils.DB, c)
// Test
assertLastActive(t, r1.UUID, int64(1257714000000))
assertDigestCount(t, r1, 1)
var repetition database.Digest
testutils.MustExec(t, db.Where("rule_id = ? AND user_id = ?", r1.ID, r1.UserID).Preload("Notes").First(&repetition), "finding repetition")
testutils.MustExec(t, testutils.DB.Where("rule_id = ? AND user_id = ?", r1.ID, r1.UserID).Preload("Notes").First(&repetition), "finding repetition")
sort.SliceStable(repetition.Notes, func(i, j int) bool {
n1 := repetition.Notes[i]
@ -391,8 +379,8 @@ func TestDo_BalancedStrategy(t *testing.T) {
})
var n2Record, n3Record database.Note
testutils.MustExec(t, db.Where("uuid = ?", dat.Note2.UUID).First(&n2Record), "finding n2")
testutils.MustExec(t, db.Where("uuid = ?", dat.Note3.UUID).First(&n3Record), "finding n3")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note2.UUID).First(&n2Record), "finding n2")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note3.UUID).First(&n3Record), "finding n3")
expected := []database.Note{n2Record, n3Record}
assert.DeepEqual(t, repetition.Notes, expected, "result mismatch")
})
@ -403,7 +391,6 @@ func TestDo_BalancedStrategy(t *testing.T) {
// Set up
dat := setup()
db := database.DBConn
t0 := time.Date(2009, time.November, 1, 12, 0, 0, 0, time.UTC)
t1 := time.Date(2009, time.November, 8, 21, 0, 0, 0, time.UTC)
r1 := database.RepetitionRule{
@ -423,20 +410,20 @@ func TestDo_BalancedStrategy(t *testing.T) {
UpdatedAt: t0,
},
}
testutils.MustExec(t, db.Save(&r1), "preparing rule1")
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
// Execute
c := clock.NewMock()
c.SetNow(time.Date(2009, time.November, 8, 21, 0, 0, 0, time.UTC))
Do(c)
Do(testutils.DB, c)
// Test
assertLastActive(t, r1.UUID, int64(1257714000000))
assertDigestCount(t, r1, 1)
var repetition database.Digest
testutils.MustExec(t, db.Where("rule_id = ? AND user_id = ?", r1.ID, r1.UserID).Preload("Notes").First(&repetition), "finding repetition")
testutils.MustExec(t, testutils.DB.Where("rule_id = ? AND user_id = ?", r1.ID, r1.UserID).Preload("Notes").First(&repetition), "finding repetition")
sort.SliceStable(repetition.Notes, func(i, j int) bool {
n1 := repetition.Notes[i]
@ -446,8 +433,8 @@ func TestDo_BalancedStrategy(t *testing.T) {
})
var n1Record, n2Record database.Note
testutils.MustExec(t, db.Where("uuid = ?", dat.Note1.UUID).First(&n1Record), "finding n1")
testutils.MustExec(t, db.Where("uuid = ?", dat.Note2.UUID).First(&n2Record), "finding n2")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note1.UUID).First(&n1Record), "finding n1")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note2.UUID).First(&n2Record), "finding n2")
expected := []database.Note{n1Record, n2Record}
assert.DeepEqual(t, repetition.Notes, expected, "result mismatch")
})

View file

@ -27,8 +27,7 @@ import (
"github.com/pkg/errors"
)
func getRuleBookIDs(ruleID int) ([]int, error) {
db := database.DBConn
func getRuleBookIDs(db *gorm.DB, ruleID int) ([]int, error) {
var ret []int
if err := db.Table("repetition_rule_books").Select("book_id").Where("repetition_rule_id = ?", ruleID).Pluck("book_id", &ret).Error; err != nil {
return nil, errors.Wrap(err, "querying book_ids")
@ -37,11 +36,11 @@ func getRuleBookIDs(ruleID int) ([]int, error) {
return ret, nil
}
func applyBookDomain(noteQuery *gorm.DB, rule database.RepetitionRule) (*gorm.DB, error) {
func applyBookDomain(db *gorm.DB, noteQuery *gorm.DB, rule database.RepetitionRule) (*gorm.DB, error) {
ret := noteQuery
if rule.BookDomain != database.BookDomainAll {
bookIDs, err := getRuleBookIDs(rule.ID)
bookIDs, err := getRuleBookIDs(db, rule.ID)
if err != nil {
return nil, errors.Wrap(err, "getting book_ids")
}
@ -58,8 +57,8 @@ func applyBookDomain(noteQuery *gorm.DB, rule database.RepetitionRule) (*gorm.DB
return ret, nil
}
func getNotes(conn *gorm.DB, rule database.RepetitionRule, dst *[]database.Note) error {
c, err := applyBookDomain(conn, rule)
func getNotes(db, conn *gorm.DB, rule database.RepetitionRule, dst *[]database.Note) error {
c, err := applyBookDomain(db, conn, rule)
if err != nil {
return errors.Wrap(err, "building query for book threahold 1")
}
@ -79,16 +78,14 @@ func getBalancedNotes(db *gorm.DB, rule database.RepetitionRule) ([]database.Not
t2 := now.AddDate(0, 0, -7).UnixNano()
// Get notes into three buckets with different threshold values
var stage1 []database.Note
var stage2 []database.Note
var stage3 []database.Note
if err := getNotes(db.Where("notes.added_on > ?", t1), rule, &stage1); err != nil {
var stage1, stage2, stage3 []database.Note
if err := getNotes(db, db.Where("notes.added_on > ?", t1), rule, &stage1); err != nil {
return nil, errors.Wrap(err, "Failed to get notes with threshold 1")
}
if err := getNotes(db.Where("notes.added_on > ? AND notes.added_on < ?", t2, t1), rule, &stage2); err != nil {
if err := getNotes(db, db.Where("notes.added_on > ? AND notes.added_on < ?", t2, t1), rule, &stage2); err != nil {
return nil, errors.Wrap(err, "Failed to get notes with threshold 2")
}
if err := getNotes(db.Where("notes.added_on < ?", t2), rule, &stage3); err != nil {
if err := getNotes(db, db.Where("notes.added_on < ?", t2), rule, &stage3); err != nil {
return nil, errors.Wrap(err, "Failed to get notes with threshold 3")
}

View file

@ -34,45 +34,43 @@ func init() {
func TestApplyBookDomain(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
user := testutils.SetupUserData()
b1 := database.Book{
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, db.Save(&b1), "preparing b1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
b2 := database.Book{
UserID: user.ID,
Label: "css",
}
testutils.MustExec(t, db.Save(&b2), "preparing b2")
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
b3 := database.Book{
UserID: user.ID,
Label: "golang",
}
testutils.MustExec(t, db.Save(&b3), "preparing b3")
testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
n1 := database.Note{
UserID: user.ID,
BookUUID: b1.UUID,
}
testutils.MustExec(t, db.Save(&n1), "preparing n1")
testutils.MustExec(t, testutils.DB.Save(&n1), "preparing n1")
n2 := database.Note{
UserID: user.ID,
BookUUID: b2.UUID,
}
testutils.MustExec(t, db.Save(&n2), "preparing n2")
testutils.MustExec(t, testutils.DB.Save(&n2), "preparing n2")
n3 := database.Note{
UserID: user.ID,
BookUUID: b3.UUID,
}
testutils.MustExec(t, db.Save(&n3), "preparing n3")
testutils.MustExec(t, testutils.DB.Save(&n3), "preparing n3")
var n1Record, n2Record, n3Record database.Note
testutils.MustExec(t, db.Where("uuid = ?", n1.UUID).First(&n1Record), "finding n1")
testutils.MustExec(t, db.Where("uuid = ?", n2.UUID).First(&n2Record), "finding n2")
testutils.MustExec(t, db.Where("uuid = ?", n3.UUID).First(&n3Record), "finding n3")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", n1.UUID).First(&n1Record), "finding n1")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", n2.UUID).First(&n2Record), "finding n2")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", n3.UUID).First(&n3Record), "finding n3")
t.Run("book domain all", func(t *testing.T) {
rule := database.RepetitionRule{
@ -80,7 +78,7 @@ func TestApplyBookDomain(t *testing.T) {
BookDomain: database.BookDomainAll,
}
conn, err := applyBookDomain(db, rule)
conn, err := applyBookDomain(testutils.DB, testutils.DB, rule)
if err != nil {
t.Fatal(errors.Wrap(err, "executing").Error())
}
@ -98,9 +96,9 @@ func TestApplyBookDomain(t *testing.T) {
BookDomain: database.BookDomainExluding,
Books: []database.Book{b1},
}
testutils.MustExec(t, db.Save(&rule), "preparing rule")
testutils.MustExec(t, testutils.DB.Save(&rule), "preparing rule")
conn, err := applyBookDomain(db.Debug(), rule)
conn, err := applyBookDomain(testutils.DB, testutils.DB, rule)
if err != nil {
t.Fatal(errors.Wrap(err, "executing").Error())
}

View file

@ -25,15 +25,17 @@ import (
"time"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/dbconn"
"github.com/dnote/dnote/pkg/server/job/repetition"
"github.com/dnote/dnote/pkg/server/mailer"
"github.com/jinzhu/gorm"
"github.com/joho/godotenv"
_ "github.com/lib/pq"
"github.com/pkg/errors"
)
func digestHandler(w http.ResponseWriter, r *http.Request) {
db := database.DBConn
func (c Context) digestHandler(w http.ResponseWriter, r *http.Request) {
db := c.DB
q := r.URL.Query()
digestUUID := q.Get("digest_uuid")
@ -61,7 +63,13 @@ func digestHandler(w http.ResponseWriter, r *http.Request) {
}
now := time.Now()
email, err := repetition.BuildEmail(now, user, "sung@getdnote.com", digest, rule)
email, err := repetition.BuildEmail(db, repetition.BuildEmailParams{
Now: now,
User: user,
EmailAddr: "sung@getdnote.com",
Digest: digest,
Rule: rule,
})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -71,7 +79,7 @@ func digestHandler(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(body))
}
func emailVerificationHandler(w http.ResponseWriter, r *http.Request) {
func (c Context) emailVerificationHandler(w http.ResponseWriter, r *http.Request) {
data := struct {
Subject string
Token string
@ -90,7 +98,7 @@ func emailVerificationHandler(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(body))
}
func homeHandler(w http.ResponseWriter, r *http.Request) {
func (c Context) homeHandler(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Email development server is running."))
}
@ -101,23 +109,29 @@ func init() {
}
}
// Context is a context holding global information
type Context struct {
DB *gorm.DB
}
func main() {
c := database.Config{
db := dbconn.Open(dbconn.Config{
Host: os.Getenv("DBHost"),
Port: os.Getenv("DBPort"),
Name: os.Getenv("DBName"),
User: os.Getenv("DBUser"),
Password: os.Getenv("DBPassword"),
}
database.Open(c)
defer database.Close()
})
defer db.Close()
mailer.InitTemplates(nil)
log.Println("Email template development server running on http://127.0.0.1:2300")
http.HandleFunc("/", homeHandler)
http.HandleFunc("/digest", digestHandler)
http.HandleFunc("/email-verification", emailVerificationHandler)
ctx := Context{DB: db}
http.HandleFunc("/", ctx.homeHandler)
http.HandleFunc("/digest", ctx.digestHandler)
http.HandleFunc("/email-verification", ctx.emailVerificationHandler)
log.Fatal(http.ListenAndServe(":2300", nil))
}

View file

@ -23,6 +23,7 @@ import (
"encoding/base64"
"github.com/dnote/dnote/pkg/server/database"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
)
@ -39,9 +40,7 @@ func generateRandomToken(bits int) (string, error) {
// GetToken returns an token of the given kind for the user
// by first looking up any unused record and creating one if none exists.
func GetToken(user database.User, kind string) (database.Token, error) {
db := database.DBConn
func GetToken(db *gorm.DB, user database.User, kind string) (database.Token, error) {
var tok database.Token
conn := db.
Where("user_id = ? AND type =? AND used_at IS NULL", user.ID, kind).

View file

@ -27,10 +27,12 @@ import (
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/dbconn"
"github.com/dnote/dnote/pkg/server/handlers"
"github.com/dnote/dnote/pkg/server/job"
"github.com/dnote/dnote/pkg/server/mailer"
"github.com/dnote/dnote/pkg/server/web"
"github.com/jinzhu/gorm"
"github.com/gobuffalo/packr/v2"
"github.com/pkg/errors"
@ -64,49 +66,70 @@ func initContext() web.Context {
}
}
func initServer() (*http.ServeMux, error) {
apiRouter, err := handlers.NewRouter(&handlers.App{
Clock: clock.New(),
StripeAPIBackend: nil,
WebURL: os.Getenv("WebURL"),
})
func initServer(app handlers.App) (*http.ServeMux, error) {
apiRouter, err := handlers.NewRouter(&app)
if err != nil {
return nil, errors.Wrap(err, "initializing router")
}
ctx := initContext()
webCtx := initContext()
webHandlers := web.Init(webCtx)
mux := http.NewServeMux()
mux.Handle("/api/", http.StripPrefix("/api", apiRouter))
mux.Handle("/static/", web.GetStaticHandler(ctx.StaticFileSystem))
mux.HandleFunc("/service-worker.js", web.GetSWHandler(ctx.ServiceWorkerJs))
mux.HandleFunc("/robots.txt", web.GetRobotsHandler(ctx.RobotsTxt))
mux.HandleFunc("/", web.GetRootHandler(ctx.IndexHTML))
mux.Handle("/static/", webHandlers.GetStatic)
mux.HandleFunc("/service-worker.js", webHandlers.GetServiceWorker)
mux.HandleFunc("/robots.txt", webHandlers.GetRobots)
mux.HandleFunc("/", webHandlers.GetRoot)
return mux, nil
}
func startCmd() {
mailer.InitTemplates(nil)
func initDB() *gorm.DB {
var skipSSL bool
if os.Getenv("GO_ENV") != "PRODUCTION" || os.Getenv("DB_NOSSL") != "" {
skipSSL = true
} else {
skipSSL = false
}
database.Open(database.Config{
db := dbconn.Open(dbconn.Config{
SkipSSL: skipSSL,
Host: os.Getenv("DBHost"),
Port: os.Getenv("DBPort"),
Name: os.Getenv("DBName"),
User: os.Getenv("DBUser"),
Password: os.Getenv("DBPassword"),
})
database.InitSchema()
defer database.Close()
database.InitSchema(db)
if err := database.Migrate(); err != nil {
return db
}
func initApp(db *gorm.DB) handlers.App {
return handlers.App{
DB: db,
Clock: clock.New(),
StripeAPIBackend: nil,
WebURL: os.Getenv("WebURL"),
}
}
func startCmd() {
db := initDB()
defer db.Close()
app := initApp(db)
mailer.InitTemplates(nil)
if err := database.Migrate(app.DB); err != nil {
panic(errors.Wrap(err, "running migrations"))
}
if err := job.Run(); err != nil {
if err := job.Run(db); err != nil {
panic(errors.Wrap(err, "running job"))
}
srv, err := initServer()
srv, err := initServer(app)
if err != nil {
panic(errors.Wrap(err, "initializing server"))
}

View file

@ -20,15 +20,14 @@ package operations
import (
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
)
// CreateBook creates a book with the next usn and updates the user's max_usn
func CreateBook(user database.User, clock clock.Clock, name string) (database.Book, error) {
db := database.DBConn
func CreateBook(db *gorm.DB, user database.User, clock clock.Clock, name string) (database.Book, error) {
tx := db.Begin()
nextUSN, err := incrementUserUSN(tx, user.ID)

View file

@ -29,10 +29,6 @@ import (
"github.com/pkg/errors"
)
func init() {
testutils.InitTestDB()
}
func TestCreateBook(t *testing.T) {
testCases := []struct {
userUSN int
@ -59,17 +55,16 @@ func TestCreateBook(t *testing.T) {
for idx, tc := range testCases {
func() {
defer testutils.ClearData()
db := database.DBConn
user := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
c := clock.NewMock()
book, err := CreateBook(user, c, tc.label)
book, err := CreateBook(testutils.DB, user, c, tc.label)
if err != nil {
t.Fatal(errors.Wrap(err, "creating book"))
}
@ -78,13 +73,13 @@ func TestCreateBook(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
if err := db.Model(&database.Book{}).Count(&bookCount).Error; err != nil {
if err := testutils.DB.Model(&database.Book{}).Count(&bookCount).Error; err != nil {
t.Fatal(errors.Wrap(err, "counting books"))
}
if err := db.First(&bookRecord).Error; err != nil {
if err := testutils.DB.First(&bookRecord).Error; err != nil {
t.Fatal(errors.Wrap(err, "finding book"))
}
if err := db.Where("id = ?", user.ID).First(&userRecord).Error; err != nil {
if err := testutils.DB.Where("id = ?", user.ID).First(&userRecord).Error; err != nil {
t.Fatal(errors.Wrap(err, "finding user"))
}
@ -124,18 +119,17 @@ func TestDeleteBook(t *testing.T) {
for idx, tc := range testCases {
func() {
defer testutils.ClearData()
db := database.DBConn
user := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
book := database.Book{UserID: user.ID, Label: "js", Deleted: false}
testutils.MustExec(t, db.Save(&book), fmt.Sprintf("preparing book for test case %d", idx))
testutils.MustExec(t, testutils.DB.Save(&book), fmt.Sprintf("preparing book for test case %d", idx))
tx := db.Begin()
tx := testutils.DB.Begin()
ret, err := DeleteBook(tx, user, book)
if err != nil {
tx.Rollback()
@ -147,9 +141,9 @@ func TestDeleteBook(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
testutils.MustExec(t, db.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
testutils.MustExec(t, testutils.DB.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, bookCount, 1, "book count mismatch")
assert.Equal(t, bookRecord.UserID, user.ID, "book user_id mismatch")
@ -202,20 +196,19 @@ func TestUpdateBook(t *testing.T) {
for idx, tc := range testCases {
func() {
defer testutils.ClearData()
db := database.DBConn
user := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
c := clock.NewMock()
b := database.Book{UserID: user.ID, Deleted: false, Label: tc.expectedLabel}
testutils.MustExec(t, db.Save(&b), fmt.Sprintf("preparing book for test case %d", idx))
testutils.MustExec(t, testutils.DB.Save(&b), fmt.Sprintf("preparing book for test case %d", idx))
tx := db.Begin()
tx := testutils.DB.Begin()
book, err := UpdateBook(tx, c, user, b, tc.payloadLabel)
if err != nil {
@ -228,9 +221,9 @@ func TestUpdateBook(t *testing.T) {
var bookCount int
var bookRecord database.Book
var userRecord database.User
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
testutils.MustExec(t, db.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
testutils.MustExec(t, testutils.DB.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, bookCount, 1, "book count mismatch")

View file

@ -28,10 +28,6 @@ import (
"github.com/pkg/errors"
)
func init() {
testutils.InitTestDB()
}
func TestIncremenetUserUSN(t *testing.T) {
testCases := []struct {
maxUSN int
@ -51,13 +47,12 @@ func TestIncremenetUserUSN(t *testing.T) {
for idx, tc := range testCases {
func() {
defer testutils.ClearData()
db := database.DBConn
user := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.maxUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.maxUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
// execute
tx := db.Begin()
tx := testutils.DB.Begin()
nextUSN, err := incrementUserUSN(tx, user.ID)
if err != nil {
t.Fatal(errors.Wrap(err, "incrementing the user usn"))
@ -66,7 +61,7 @@ func TestIncremenetUserUSN(t *testing.T) {
// test
var userRecord database.User
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, userRecord.MaxUSN, tc.expectedMaxUSN, fmt.Sprintf("user max_usn mismatch for case %d", idx))
assert.Equal(t, nextUSN, tc.expectedMaxUSN, fmt.Sprintf("next_usn mismatch for case %d", idx))

View file

@ -0,0 +1,17 @@
package operations
import (
"os"
"testing"
"github.com/dnote/dnote/pkg/server/testutils"
)
func TestMain(m *testing.M) {
testutils.InitTestDB()
code := m.Run()
testutils.ClearData()
os.Exit(code)
}

View file

@ -29,8 +29,7 @@ import (
// CreateNote creates a note with the next usn and updates the user's max_usn.
// It returns the created note.
func CreateNote(user database.User, clock clock.Clock, bookUUID, content string, addedOn *int64, editedOn *int64, public bool) (database.Note, error) {
db := database.DBConn
func CreateNote(db *gorm.DB, user database.User, clock clock.Clock, bookUUID, content string, addedOn *int64, editedOn *int64, public bool) (database.Note, error) {
tx := db.Begin()
nextUSN, err := incrementUserUSN(tx, user.ID)
@ -163,14 +162,12 @@ func DeleteNote(tx *gorm.DB, user database.User, note database.Note) (database.N
}
// GetNote retrieves a note for the given user
func GetNote(uuid string, user database.User) (database.Note, bool, error) {
func GetNote(db *gorm.DB, uuid string, user database.User) (database.Note, bool, error) {
zeroNote := database.Note{}
if !helpers.ValidateUUID(uuid) {
return zeroNote, false, nil
}
db := database.DBConn
conn := db.Where("notes.uuid = ? AND deleted = ?", uuid, false)
conn = database.PreloadNote(conn)

View file

@ -30,10 +30,6 @@ import (
"github.com/pkg/errors"
)
func init() {
testutils.InitTestDB()
}
func TestCreateNote(t *testing.T) {
serverTime := time.Date(2017, time.March, 14, 21, 15, 0, 0, time.UTC)
mockClock := clock.NewMock()
@ -79,19 +75,18 @@ func TestCreateNote(t *testing.T) {
for idx, tc := range testCases {
func() {
defer testutils.ClearData()
db := database.DBConn
user := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
b1 := database.Book{UserID: user.ID, Label: "js", Deleted: false}
testutils.MustExec(t, db.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
testutils.MustExec(t, testutils.DB.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
tx := db.Begin()
if _, err := CreateNote(user, mockClock, b1.UUID, "note content", tc.addedOn, tc.editedOn, false); err != nil {
tx := testutils.DB.Begin()
if _, err := CreateNote(testutils.DB, user, mockClock, b1.UUID, "note content", tc.addedOn, tc.editedOn, false); err != nil {
tx.Rollback()
t.Fatal(errors.Wrap(err, "deleting note"))
}
@ -101,10 +96,10 @@ func TestCreateNote(t *testing.T) {
var noteRecord database.Note
var userRecord database.User
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting book for test case %d", idx))
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), fmt.Sprintf("counting notes for test case %d", idx))
testutils.MustExec(t, db.First(&noteRecord), fmt.Sprintf("finding note for test case %d", idx))
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting book for test case %d", idx))
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), fmt.Sprintf("counting notes for test case %d", idx))
testutils.MustExec(t, testutils.DB.First(&noteRecord), fmt.Sprintf("finding note for test case %d", idx))
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, bookCount, 1, "book count mismatch")
assert.Equal(t, noteCount, 1, "note count mismatch")
@ -139,25 +134,24 @@ func TestUpdateNote(t *testing.T) {
for idx, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
user := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), "preparing user max_usn for test case")
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), "preparing user max_usn for test case")
anotherUser := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), "preparing user max_usn for test case")
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), "preparing user max_usn for test case")
b1 := database.Book{UserID: user.ID, Label: "js", Deleted: false}
testutils.MustExec(t, db.Save(&b1), "preparing b1 for test case")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1 for test case")
note := database.Note{UserID: user.ID, Deleted: false, Body: "test content", BookUUID: b1.UUID}
testutils.MustExec(t, db.Save(&note), "preparing note for test case")
testutils.MustExec(t, testutils.DB.Save(&note), "preparing note for test case")
c := clock.NewMock()
content := "updated test content"
public := true
tx := db.Begin()
tx := testutils.DB.Begin()
if _, err := UpdateNote(tx, user, c, note, &UpdateNoteParams{
Content: &content,
Public: &public,
@ -171,10 +165,10 @@ func TestUpdateNote(t *testing.T) {
var noteRecord database.Note
var userRecord database.User
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting book for test case")
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), "counting notes for test case")
testutils.MustExec(t, db.First(&noteRecord), "finding note for test case")
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user for test case")
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting book for test case")
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), "counting notes for test case")
testutils.MustExec(t, testutils.DB.First(&noteRecord), "finding note for test case")
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user for test case")
expectedUSN := tc.userUSN + 1
assert.Equal(t, bookCount, 1, "book count mismatch")
@ -211,21 +205,20 @@ func TestDeleteNote(t *testing.T) {
for idx, tc := range testCases {
func() {
defer testutils.ClearData()
db := database.DBConn
user := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData()
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
b1 := database.Book{UserID: user.ID, Label: "testBook"}
testutils.MustExec(t, db.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
testutils.MustExec(t, testutils.DB.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
note := database.Note{UserID: user.ID, Deleted: false, Body: "test content", BookUUID: b1.UUID}
testutils.MustExec(t, db.Save(&note), fmt.Sprintf("preparing note for test case %d", idx))
testutils.MustExec(t, testutils.DB.Save(&note), fmt.Sprintf("preparing note for test case %d", idx))
tx := db.Begin()
tx := testutils.DB.Begin()
ret, err := DeleteNote(tx, user, note)
if err != nil {
tx.Rollback()
@ -237,9 +230,9 @@ func TestDeleteNote(t *testing.T) {
var noteRecord database.Note
var userRecord database.User
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), fmt.Sprintf("counting notes for test case %d", idx))
testutils.MustExec(t, db.First(&noteRecord), fmt.Sprintf("finding note for test case %d", idx))
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), fmt.Sprintf("counting notes for test case %d", idx))
testutils.MustExec(t, testutils.DB.First(&noteRecord), fmt.Sprintf("finding note for test case %d", idx))
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, noteCount, 1, "note count mismatch")
@ -261,14 +254,13 @@ func TestGetNote(t *testing.T) {
user := testutils.SetupUserData()
anotherUser := testutils.SetupUserData()
db := database.DBConn
defer testutils.ClearData()
b1 := database.Book{
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, db.Save(&b1), "preparing b1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
privateNote := database.Note{
UserID: user.ID,
@ -277,7 +269,7 @@ func TestGetNote(t *testing.T) {
Deleted: false,
Public: false,
}
testutils.MustExec(t, db.Save(&privateNote), "preparing privateNote")
testutils.MustExec(t, testutils.DB.Save(&privateNote), "preparing privateNote")
publicNote := database.Note{
UserID: user.ID,
@ -286,11 +278,11 @@ func TestGetNote(t *testing.T) {
Deleted: false,
Public: true,
}
testutils.MustExec(t, db.Save(&publicNote), "preparing privateNote")
testutils.MustExec(t, testutils.DB.Save(&publicNote), "preparing privateNote")
var privateNoteRecord, publicNoteRecord database.Note
testutils.MustExec(t, db.Where("uuid = ?", privateNote.UUID).Preload("Book").Preload("User").First(&privateNoteRecord), "finding privateNote")
testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).Preload("Book").Preload("User").First(&publicNoteRecord), "finding publicNote")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", privateNote.UUID).Preload("Book").Preload("User").First(&privateNoteRecord), "finding privateNote")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).Preload("Book").Preload("User").First(&publicNoteRecord), "finding publicNote")
testCases := []struct {
name string
@ -338,7 +330,7 @@ func TestGetNote(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
note, ok, err := GetNote(tc.note.UUID, tc.user)
note, ok, err := GetNote(testutils.DB, tc.note.UUID, tc.user)
if err != nil {
t.Fatal(errors.Wrap(err, "executing"))
}
@ -352,14 +344,13 @@ func TestGetNote(t *testing.T) {
func TestGetNote_nonexistent(t *testing.T) {
user := testutils.SetupUserData()
db := database.DBConn
defer testutils.ClearData()
b1 := database.Book{
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, db.Save(&b1), "preparing b1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
n1UUID := "4fd19336-671e-4ff3-8f22-662b80e22edc"
n1 := database.Note{
@ -370,10 +361,10 @@ func TestGetNote_nonexistent(t *testing.T) {
Deleted: false,
Public: false,
}
testutils.MustExec(t, db.Save(&n1), "preparing n1")
testutils.MustExec(t, testutils.DB.Save(&n1), "preparing n1")
nonexistentUUID := "4fd19336-671e-4ff3-8f22-662b80e22edd"
note, ok, err := GetNote(nonexistentUUID, user)
note, ok, err := GetNote(testutils.DB, nonexistentUUID, user)
if err != nil {
t.Fatal(errors.Wrap(err, "executing"))
}

View file

@ -22,6 +22,7 @@ import (
"github.com/dnote/dnote/pkg/server/database"
"github.com/pkg/errors"
"github.com/jinzhu/gorm"
"github.com/stripe/stripe-go"
"github.com/stripe/stripe-go/sub"
)
@ -66,9 +67,7 @@ func ReactivateSub(subscriptionID string, user database.User) error {
}
// MarkUnsubscribed marks the user unsubscribed
func MarkUnsubscribed(stripeCustomerID string) error {
db := database.DBConn
func MarkUnsubscribed(db *gorm.DB, stripeCustomerID string) error {
var user database.User
if err := db.Where("stripe_customer_id = ?", stripeCustomerID).First(&user).Error; err != nil {
return errors.Wrap(err, "finding user")

View file

@ -104,8 +104,7 @@ func createDefaultRepetitionRule(user database.User, tx *gorm.DB) error {
}
// CreateUser creates a user
func CreateUser(email, password string) (database.User, error) {
db := database.DBConn
func CreateUser(db *gorm.DB, email, password string) (database.User, error) {
tx := db.Begin()
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)

View file

@ -19,6 +19,7 @@
package permissions
import (
"os"
"testing"
"github.com/dnote/dnote/pkg/assert"
@ -26,22 +27,26 @@ import (
"github.com/dnote/dnote/pkg/server/testutils"
)
func init() {
func TestMain(m *testing.M) {
testutils.InitTestDB()
code := m.Run()
testutils.ClearData()
os.Exit(code)
}
func TestViewNote(t *testing.T) {
user := testutils.SetupUserData()
anotherUser := testutils.SetupUserData()
db := database.DBConn
defer testutils.ClearData()
b1 := database.Book{
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, db.Save(&b1), "preparing b1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
privateNote := database.Note{
UserID: user.ID,
@ -50,7 +55,7 @@ func TestViewNote(t *testing.T) {
Deleted: false,
Public: false,
}
testutils.MustExec(t, db.Save(&privateNote), "preparing privateNote")
testutils.MustExec(t, testutils.DB.Save(&privateNote), "preparing privateNote")
publicNote := database.Note{
UserID: user.ID,
@ -59,7 +64,7 @@ func TestViewNote(t *testing.T) {
Deleted: false,
Public: true,
}
testutils.MustExec(t, db.Save(&publicNote), "preparing privateNote")
testutils.MustExec(t, testutils.DB.Save(&publicNote), "preparing privateNote")
t.Run("owner accessing private note", func(t *testing.T) {
result := ViewNote(&user, privateNote)

View file

@ -32,6 +32,7 @@ import (
"time"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/dbconn"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
"github.com/stripe/stripe-go"
@ -42,31 +43,67 @@ func init() {
rand.Seed(time.Now().UnixNano())
}
// DB is the database connection to a test database
var DB *gorm.DB
// InitTestDB establishes connection pool with the test database specified by
// the environment variable configuration and initalizes a new schema
func InitTestDB() {
c := database.Config{
db := dbconn.Open(dbconn.Config{
Host: os.Getenv("DBHost"),
Port: os.Getenv("DBPort"),
Name: os.Getenv("DBName"),
User: os.Getenv("DBUser"),
Password: os.Getenv("DBPassword"),
})
database.InitSchema(db)
DB = db
}
// ClearData deletes all records from the database
func ClearData() {
if err := DB.Delete(&database.Book{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear books"))
}
if err := DB.Delete(&database.Note{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear notes"))
}
if err := DB.Delete(&database.Notification{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear notifications"))
}
if err := DB.Delete(&database.User{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear users"))
}
if err := DB.Delete(&database.Account{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear accounts"))
}
if err := DB.Delete(&database.Token{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear reset_tokens"))
}
if err := DB.Delete(&database.EmailPreference{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear reset_tokens"))
}
if err := DB.Delete(&database.Session{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear sessions"))
}
if err := DB.Delete(&database.Digest{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear digests"))
}
if err := DB.Delete(&database.RepetitionRule{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear digests"))
}
database.Open(c)
database.InitSchema()
}
// SetupUserData creates and returns a new user for testing purposes
func SetupUserData() database.User {
db := database.DBConn
user := database.User{
APIKey: "test-api-key",
Name: "user-name",
Cloud: true,
}
if err := db.Save(&user).Error; err != nil {
if err := DB.Save(&user).Error; err != nil {
panic(errors.Wrap(err, "Failed to prepare user"))
}
@ -75,8 +112,6 @@ func SetupUserData() database.User {
// SetupAccountData creates and returns a new account for the user
func SetupAccountData(user database.User, email, password string) database.Account {
db := database.DBConn
account := database.Account{
UserID: user.ID,
}
@ -90,7 +125,7 @@ func SetupAccountData(user database.User, email, password string) database.Accou
}
account.Password = database.ToNullString(string(hashedPassword))
if err := db.Save(&account).Error; err != nil {
if err := DB.Save(&account).Error; err != nil {
panic(errors.Wrap(err, "Failed to prepare account"))
}
@ -99,8 +134,6 @@ func SetupAccountData(user database.User, email, password string) database.Accou
// SetupClassicAccountData creates and returns a new account for the user
func SetupClassicAccountData(user database.User, email string) database.Account {
db := database.DBConn
// email: alice@example.com
// password: pass1234
// masterKey: WbUvagj9O6o1Z+4+7COjo7Uqm4MD2QE9EWFXne8+U+8=
@ -117,7 +150,7 @@ func SetupClassicAccountData(user database.User, email string) database.Account
account.Email = database.ToNullString(email)
}
if err := db.Save(&account).Error; err != nil {
if err := DB.Save(&account).Error; err != nil {
panic(errors.Wrap(err, "Failed to prepare account"))
}
@ -126,14 +159,12 @@ func SetupClassicAccountData(user database.User, email string) database.Account
// SetupSession creates and returns a new user session
func SetupSession(t *testing.T, user database.User) database.Session {
db := database.DBConn
session := database.Session{
Key: "Vvgm3eBXfXGEFWERI7faiRJ3DAzJw+7DdT9J1LEyNfI=",
UserID: user.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
if err := db.Save(&session).Error; err != nil {
if err := DB.Save(&session).Error; err != nil {
t.Fatal(errors.Wrap(err, "Failed to prepare user"))
}
@ -142,56 +173,18 @@ func SetupSession(t *testing.T, user database.User) database.Session {
// SetupEmailPreferenceData creates and returns a new email frequency for a user
func SetupEmailPreferenceData(user database.User, digestWeekly bool) database.EmailPreference {
db := database.DBConn
frequency := database.EmailPreference{
UserID: user.ID,
DigestWeekly: digestWeekly,
}
if err := db.Save(&frequency).Error; err != nil {
if err := DB.Save(&frequency).Error; err != nil {
panic(errors.Wrap(err, "Failed to prepare email frequency"))
}
return frequency
}
// ClearData deletes all records from the database
func ClearData() {
db := database.DBConn
if err := db.Delete(&database.Book{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear books"))
}
if err := db.Delete(&database.Note{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear notes"))
}
if err := db.Delete(&database.Notification{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear notifications"))
}
if err := db.Delete(&database.User{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear users"))
}
if err := db.Delete(&database.Account{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear accounts"))
}
if err := db.Delete(&database.Token{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear reset_tokens"))
}
if err := db.Delete(&database.EmailPreference{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear reset_tokens"))
}
if err := db.Delete(&database.Session{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear sessions"))
}
if err := db.Delete(&database.Digest{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear digests"))
}
if err := db.Delete(&database.RepetitionRule{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear digests"))
}
}
// HTTPDo makes an HTTP request and returns a response
func HTTPDo(t *testing.T, req *http.Request) *http.Response {
hc := http.Client{
@ -213,8 +206,6 @@ func HTTPDo(t *testing.T, req *http.Request) *http.Response {
// HTTPAuthDo makes an HTTP request with an appropriate authorization header for a user
func HTTPAuthDo(t *testing.T, req *http.Request, user database.User) *http.Response {
db := database.DBConn
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
t.Fatal(errors.Wrap(err, "reading random bits"))
@ -225,7 +216,7 @@ func HTTPAuthDo(t *testing.T, req *http.Request, user database.User) *http.Respo
UserID: user.ID,
ExpiresAt: time.Now().Add(time.Hour * 10 * 24),
}
if err := db.Save(&session).Error; err != nil {
if err := DB.Save(&session).Error; err != nil {
t.Fatal(errors.Wrap(err, "Failed to prepare user"))
}

View file

@ -24,6 +24,7 @@ import (
"net/http"
"regexp"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
)
@ -58,8 +59,8 @@ func NewAppShell(content []byte) (AppShell, error) {
}
// Execute executes the index template
func (a AppShell) Execute(r *http.Request) ([]byte, error) {
data, err := a.getData(r)
func (a AppShell) Execute(r *http.Request, db *gorm.DB) ([]byte, error) {
data, err := a.getData(db, r)
if err != nil {
return nil, errors.Wrap(err, "getting data")
}
@ -72,11 +73,11 @@ func (a AppShell) Execute(r *http.Request) ([]byte, error) {
return buf.Bytes(), nil
}
func (a AppShell) getData(r *http.Request) (tmplData, error) {
func (a AppShell) getData(db *gorm.DB, r *http.Request) (tmplData, error) {
path := r.URL.Path
if ok, params := matchPath(path, notesPathRegex); ok {
p, err := a.newNotePage(r, params[0])
p, err := a.newNotePage(db, r, params[0])
if err != nil {
return tmplData{}, errors.Wrap(err, "instantiating note page")
}

View file

@ -29,10 +29,6 @@ import (
"github.com/pkg/errors"
)
func init() {
testutils.InitTestDB()
}
func TestAppShellExecute(t *testing.T) {
t.Run("home", func(t *testing.T) {
a, err := NewAppShell([]byte("<head><title>{{ .Title }}</title>{{ .MetaTags }}</head>"))
@ -45,7 +41,7 @@ func TestAppShellExecute(t *testing.T) {
t.Fatal(errors.Wrap(err, "preparing request"))
}
b, err := a.Execute(r)
b, err := a.Execute(r, testutils.DB)
if err != nil {
t.Fatal(errors.Wrap(err, "executing"))
}
@ -55,21 +51,20 @@ func TestAppShellExecute(t *testing.T) {
t.Run("note", func(t *testing.T) {
defer testutils.ClearData()
db := database.DBConn
user := testutils.SetupUserData()
b1 := database.Book{
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, db.Save(&b1), "preparing b1")
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
n1 := database.Note{
UserID: user.ID,
BookUUID: b1.UUID,
Public: true,
Body: "n1 content",
}
testutils.MustExec(t, db.Save(&n1), "preparing note")
testutils.MustExec(t, testutils.DB.Save(&n1), "preparing note")
a, err := NewAppShell([]byte("{{ .MetaTags }}"))
if err != nil {
@ -82,7 +77,7 @@ func TestAppShellExecute(t *testing.T) {
t.Fatal(errors.Wrap(err, "preparing request"))
}
b, err := a.Execute(r)
b, err := a.Execute(r, testutils.DB)
if err != nil {
t.Fatal(errors.Wrap(err, "executing"))
}

View file

@ -30,6 +30,7 @@ import (
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/handlers"
"github.com/dnote/dnote/pkg/server/operations"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
)
@ -51,13 +52,13 @@ type notePage struct {
T *template.Template
}
func (a AppShell) newNotePage(r *http.Request, noteUUID string) (notePage, error) {
user, _, err := handlers.AuthWithSession(r, nil)
func (a AppShell) newNotePage(db *gorm.DB, r *http.Request, noteUUID string) (notePage, error) {
user, _, err := handlers.AuthWithSession(db, r, nil)
if err != nil {
return notePage{}, errors.Wrap(err, "authenticating with session")
}
note, ok, err := operations.GetNote(noteUUID, user)
note, ok, err := operations.GetNote(db, noteUUID, user)
if !ok {
return notePage{}, ErrNotFound

View file

@ -0,0 +1,17 @@
package tmpl
import (
"os"
"testing"
"github.com/dnote/dnote/pkg/server/testutils"
)
func TestMain(m *testing.M) {
testutils.InitTestDB()
code := m.Run()
testutils.ClearData()
os.Exit(code)
}

View file

@ -24,20 +24,40 @@ import (
"github.com/dnote/dnote/pkg/server/handlers"
"github.com/dnote/dnote/pkg/server/tmpl"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
)
// Context contains contents of web assets
type Context struct {
DB *gorm.DB
IndexHTML []byte
RobotsTxt []byte
ServiceWorkerJs []byte
StaticFileSystem http.FileSystem
}
// GetRootHandler returns an HTTP handler that serves the app shell
func GetRootHandler(b []byte) http.HandlerFunc {
appShell, err := tmpl.NewAppShell(b)
// Handlers are a group of web handlers
type Handlers struct {
GetRoot http.HandlerFunc
GetRobots http.HandlerFunc
GetServiceWorker http.HandlerFunc
GetStatic http.Handler
}
// Init initializes the handlers
func Init(c Context) Handlers {
return Handlers{
GetRoot: getRootHandler(c),
GetRobots: getRobotsHandler(c),
GetServiceWorker: getSWHandler(c),
GetStatic: getStaticHandler(c),
}
}
// getRootHandler returns an HTTP handler that serves the app shell
func getRootHandler(c Context) http.HandlerFunc {
appShell, err := tmpl.NewAppShell(c.IndexHTML)
if err != nil {
panic(errors.Wrap(err, "initializing app shell"))
}
@ -46,7 +66,7 @@ func GetRootHandler(b []byte) http.HandlerFunc {
// index.html must not be cached
w.Header().Set("Cache-Control", "no-cache")
buf, err := appShell.Execute(r)
buf, err := appShell.Execute(r, c.DB)
if err != nil {
if errors.Cause(err) == tmpl.ErrNotFound {
handlers.RespondNotFound(w)
@ -60,24 +80,25 @@ func GetRootHandler(b []byte) http.HandlerFunc {
}
}
// GetRobotsHandler returns an HTTP handler that serves robots.txt
func GetRobotsHandler(b []byte) http.HandlerFunc {
// getRobotsHandler returns an HTTP handler that serves robots.txt
func getRobotsHandler(c Context) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Cache-Control", "no-cache")
w.Write(b)
w.Write(c.RobotsTxt)
}
}
// GetSWHandler returns an HTTP handler that serves service worker
func GetSWHandler(b []byte) http.HandlerFunc {
// getSWHandler returns an HTTP handler that serves service worker
func getSWHandler(c Context) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Content-Type", "application/javascript")
w.Write(b)
w.Write(c.ServiceWorkerJs)
}
}
// GetStaticHandler returns an HTTP handler that serves static files from a filesystem
func GetStaticHandler(root http.FileSystem) http.Handler {
// getStaticHandler returns an HTTP handler that serves static files from a filesystem
func getStaticHandler(c Context) http.Handler {
root := c.StaticFileSystem
return http.StripPrefix("/static/", http.FileServer(root))
}

View file

@ -19,26 +19,27 @@
package main
import (
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/dbconn"
"github.com/dnote/dnote/pkg/server/helpers"
"os"
"time"
)
func main() {
c := database.Config{
db, err := dbconn.Open(dbconn.Config{
Host: os.Getenv("DBHost"),
Port: os.Getenv("DBPort"),
Name: os.Getenv("DBName"),
User: os.Getenv("DBUser"),
Password: os.Getenv("DBPassword"),
})
if err != nil {
panic(err)
}
database.Open(c)
db := database.DBConn
tx := db.Begin()
userID, err := helpers.GetDemoUserID()
userID, err := helpers.GetDemoUserID(db)
if err != nil {
panic(err)
}

View file

@ -13,6 +13,7 @@ if [ "${WATCH-false}" == true ]; then
while inotifywait --exclude .swp -e modify -r .; do go test ./... -cover -p 1; done;
set -e
else
# go test ./... -cover -p 1
go test ./... -cover -p 1
fi

View file

@ -1,2 +1,2 @@
User-agent: *
Disallow: /
Allow: /