mirror of
https://github.com/dnote/dnote
synced 2026-03-14 22:45:50 +01:00
Refactor to avoid global database variable (#313)
* Avoid global database * Fix Twitter summary card * Fix CLI test
This commit is contained in:
parent
ec6773dc45
commit
bd97209af8
56 changed files with 1056 additions and 1058 deletions
|
|
@ -19,11 +19,7 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
// Use postgres
|
||||
_ "github.com/lib/pq"
|
||||
|
|
@ -34,97 +30,13 @@ var (
|
|||
MigrationTableName = "migrations"
|
||||
)
|
||||
|
||||
// Config holds the connection configuration
|
||||
type Config struct {
|
||||
Host string
|
||||
Port string
|
||||
Name string
|
||||
User string
|
||||
Password string
|
||||
}
|
||||
|
||||
// ErrConfigMissingHost is an error for an incomplete configuration missing the host
|
||||
var ErrConfigMissingHost = errors.New("Host is empty")
|
||||
|
||||
// ErrConfigMissingPort is an error for an incomplete configuration missing the port
|
||||
var ErrConfigMissingPort = errors.New("Port is empty")
|
||||
|
||||
// ErrConfigMissingName is an error for an incomplete configuration missing the name
|
||||
var ErrConfigMissingName = errors.New("Name is empty")
|
||||
|
||||
// ErrConfigMissingUser is an error for an incomplete configuration missing the user
|
||||
var ErrConfigMissingUser = errors.New("User is empty")
|
||||
|
||||
func validateConfig(c Config) error {
|
||||
if c.Host == "" {
|
||||
return ErrConfigMissingHost
|
||||
}
|
||||
if c.Port == "" {
|
||||
return ErrConfigMissingPort
|
||||
}
|
||||
if c.Name == "" {
|
||||
return ErrConfigMissingName
|
||||
}
|
||||
if c.User == "" {
|
||||
return ErrConfigMissingUser
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getPGConnectionString(c Config) (string, error) {
|
||||
if err := validateConfig(c); err != nil {
|
||||
return "", errors.Wrap(err, "invalid database config")
|
||||
}
|
||||
|
||||
var sslmode string
|
||||
if os.Getenv("GO_ENV") == "PRODUCTION" && os.Getenv("DB_NOSSL") == "" {
|
||||
sslmode = "require"
|
||||
} else {
|
||||
sslmode = "disable"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"sslmode=%s host=%s port=%s dbname=%s user=%s password=%s",
|
||||
sslmode,
|
||||
c.Host,
|
||||
c.Port,
|
||||
c.Name,
|
||||
c.User,
|
||||
c.Password,
|
||||
), nil
|
||||
}
|
||||
|
||||
var (
|
||||
// DBConn is the connection handle for the database
|
||||
DBConn *gorm.DB
|
||||
)
|
||||
|
||||
// Open opens the connection with the database
|
||||
func Open(c Config) {
|
||||
connStr, err := getPGConnectionString(c)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
DBConn, err = gorm.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes database connection
|
||||
func Close() {
|
||||
DBConn.Close()
|
||||
}
|
||||
|
||||
// InitSchema migrates database schema to reflect the latest model definition
|
||||
func InitSchema() {
|
||||
if err := DBConn.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`).Error; err != nil {
|
||||
func InitSchema(db *gorm.DB) {
|
||||
if err := db.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := DBConn.AutoMigrate(
|
||||
if err := db.AutoMigrate(
|
||||
Note{},
|
||||
Book{},
|
||||
User{},
|
||||
|
|
|
|||
|
|
@ -22,20 +22,20 @@ import (
|
|||
"log"
|
||||
|
||||
"github.com/gobuffalo/packr/v2"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rubenv/sql-migrate"
|
||||
)
|
||||
|
||||
// Migrate runs the migrations
|
||||
func Migrate() error {
|
||||
func Migrate(db *gorm.DB) error {
|
||||
migrations := &migrate.PackrMigrationSource{
|
||||
Box: packr.New("migrations", "../database/migrations/"),
|
||||
}
|
||||
|
||||
migrate.SetTable(MigrationTableName)
|
||||
|
||||
db := DBConn.DB()
|
||||
n, err := migrate.Exec(db, "postgres", migrations, migrate.Up)
|
||||
n, err := migrate.Exec(db.DB(), "postgres", migrations, migrate.Up)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "running migrations")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/dbconn"
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rubenv/sql-migrate"
|
||||
|
|
@ -43,20 +43,18 @@ func init() {
|
|||
}
|
||||
}
|
||||
|
||||
c := database.Config{
|
||||
Host: os.Getenv("DBHost"),
|
||||
Port: os.Getenv("DBPort"),
|
||||
Name: os.Getenv("DBName"),
|
||||
User: os.Getenv("DBUser"),
|
||||
Password: os.Getenv("DBPassword"),
|
||||
}
|
||||
database.Open(c)
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
db := database.DBConn
|
||||
db := dbconn.Open(dbconn.Config{
|
||||
Host: os.Getenv("DBHost"),
|
||||
Port: os.Getenv("DBPort"),
|
||||
Name: os.Getenv("DBName"),
|
||||
User: os.Getenv("DBUser"),
|
||||
Password: os.Getenv("DBPassword"),
|
||||
})
|
||||
|
||||
migrations := &migrate.FileMigrationSource{
|
||||
Dir: *migrationDir,
|
||||
|
|
|
|||
85
pkg/server/dbconn/dbconn.go
Normal file
85
pkg/server/dbconn/dbconn.go
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
package dbconn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Config holds the connection configuration
|
||||
type Config struct {
|
||||
SkipSSL bool
|
||||
Host string
|
||||
Port string
|
||||
Name string
|
||||
User string
|
||||
Password string
|
||||
}
|
||||
|
||||
// ErrConfigMissingHost is an error for an incomplete configuration missing the host
|
||||
var ErrConfigMissingHost = errors.New("Host is empty")
|
||||
|
||||
// ErrConfigMissingPort is an error for an incomplete configuration missing the port
|
||||
var ErrConfigMissingPort = errors.New("Port is empty")
|
||||
|
||||
// ErrConfigMissingName is an error for an incomplete configuration missing the name
|
||||
var ErrConfigMissingName = errors.New("Name is empty")
|
||||
|
||||
// ErrConfigMissingUser is an error for an incomplete configuration missing the user
|
||||
var ErrConfigMissingUser = errors.New("User is empty")
|
||||
|
||||
func validateConfig(c Config) error {
|
||||
if c.Host == "" {
|
||||
return ErrConfigMissingHost
|
||||
}
|
||||
if c.Port == "" {
|
||||
return ErrConfigMissingPort
|
||||
}
|
||||
if c.Name == "" {
|
||||
return ErrConfigMissingName
|
||||
}
|
||||
if c.User == "" {
|
||||
return ErrConfigMissingUser
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getPGConnectionString(c Config) (string, error) {
|
||||
if err := validateConfig(c); err != nil {
|
||||
return "", errors.Wrap(err, "invalid database config")
|
||||
}
|
||||
|
||||
var sslmode string
|
||||
if c.SkipSSL {
|
||||
sslmode = "disable"
|
||||
} else {
|
||||
sslmode = "require"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"sslmode=%s host=%s port=%s dbname=%s user=%s password=%s",
|
||||
sslmode,
|
||||
c.Host,
|
||||
c.Port,
|
||||
c.Name,
|
||||
c.User,
|
||||
c.Password,
|
||||
), nil
|
||||
}
|
||||
|
||||
// Open opens the connection with the database
|
||||
func Open(c Config) *gorm.DB {
|
||||
connStr, err := getPGConnectionString(c)
|
||||
if err != nil {
|
||||
panic(errors.Wrap(err, "getting connection string"))
|
||||
}
|
||||
|
||||
conn, err := gorm.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
panic(errors.Wrap(err, "opening database connection"))
|
||||
}
|
||||
|
||||
return conn
|
||||
}
|
||||
|
|
@ -16,7 +16,7 @@
|
|||
* along with Dnote. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package database
|
||||
package dbconn
|
||||
|
||||
import (
|
||||
"github.com/dnote/dnote/pkg/assert"
|
||||
|
|
@ -38,6 +38,17 @@ func TestValidateConfig(t *testing.T) {
|
|||
},
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
input: Config{
|
||||
SkipSSL: true,
|
||||
Host: "mockHost",
|
||||
Port: "mockPort",
|
||||
Name: "mockName",
|
||||
User: "mockUser",
|
||||
Password: "mockPassword",
|
||||
},
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
input: Config{
|
||||
Host: "mockHost",
|
||||
|
|
@ -24,10 +24,10 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/dnote/dnote/pkg/server/helpers"
|
||||
"github.com/dnote/dnote/pkg/server/operations"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/helpers"
|
||||
"github.com/dnote/dnote/pkg/server/mailer"
|
||||
"github.com/dnote/dnote/pkg/server/operations"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
|
@ -60,10 +60,8 @@ func (a *App) getMe(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
db := database.DBConn
|
||||
|
||||
var account database.Account
|
||||
if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
|
||||
if err := a.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
|
||||
HandleError(w, "finding account", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -76,7 +74,7 @@ func (a *App) getMe(w http.ResponseWriter, r *http.Request) {
|
|||
User: session,
|
||||
}
|
||||
|
||||
tx := db.Begin()
|
||||
tx := a.DB.Begin()
|
||||
if err := operations.TouchLastLoginAt(user, tx); err != nil {
|
||||
tx.Rollback()
|
||||
// In case of an error, gracefully continue to avoid disturbing the service
|
||||
|
|
@ -92,8 +90,6 @@ type createResetTokenPayload struct {
|
|||
}
|
||||
|
||||
func (a *App) createResetToken(w http.ResponseWriter, r *http.Request) {
|
||||
db := database.DBConn
|
||||
|
||||
var params createResetTokenPayload
|
||||
if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil {
|
||||
http.Error(w, "invalid payload", http.StatusBadRequest)
|
||||
|
|
@ -101,7 +97,7 @@ func (a *App) createResetToken(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var account database.Account
|
||||
conn := db.Where("email = ?", params.Email).First(&account)
|
||||
conn := a.DB.Where("email = ?", params.Email).First(&account)
|
||||
if conn.RecordNotFound() {
|
||||
return
|
||||
}
|
||||
|
|
@ -127,7 +123,7 @@ func (a *App) createResetToken(w http.ResponseWriter, r *http.Request) {
|
|||
Type: database.TokenTypeResetPassword,
|
||||
}
|
||||
|
||||
if err := db.Save(&token).Error; err != nil {
|
||||
if err := a.DB.Save(&token).Error; err != nil {
|
||||
HandleError(w, errors.Wrap(err, "saving token").Error(), nil, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -156,8 +152,6 @@ type resetPasswordPayload struct {
|
|||
}
|
||||
|
||||
func (a *App) resetPassword(w http.ResponseWriter, r *http.Request) {
|
||||
db := database.DBConn
|
||||
|
||||
var params resetPasswordPayload
|
||||
if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil {
|
||||
http.Error(w, "invalid payload", http.StatusBadRequest)
|
||||
|
|
@ -165,7 +159,7 @@ func (a *App) resetPassword(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var token database.Token
|
||||
conn := db.Where("value = ? AND type =? AND used_at IS NULL", params.Token, database.TokenTypeResetPassword).First(&token)
|
||||
conn := a.DB.Where("value = ? AND type =? AND used_at IS NULL", params.Token, database.TokenTypeResetPassword).First(&token)
|
||||
if conn.RecordNotFound() {
|
||||
http.Error(w, "invalid token", http.StatusBadRequest)
|
||||
return
|
||||
|
|
@ -186,7 +180,7 @@ func (a *App) resetPassword(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
tx := db.Begin()
|
||||
tx := a.DB.Begin()
|
||||
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(params.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
|
|
@ -196,7 +190,7 @@ func (a *App) resetPassword(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var account database.Account
|
||||
if err := db.Where("user_id = ?", token.UserID).First(&account).Error; err != nil {
|
||||
if err := a.DB.Where("user_id = ?", token.UserID).First(&account).Error; err != nil {
|
||||
tx.Rollback()
|
||||
HandleError(w, errors.Wrap(err, "finding user").Error(), nil, http.StatusInternalServerError)
|
||||
return
|
||||
|
|
@ -216,10 +210,10 @@ func (a *App) resetPassword(w http.ResponseWriter, r *http.Request) {
|
|||
tx.Commit()
|
||||
|
||||
var user database.User
|
||||
if err := db.Where("id = ?", account.UserID).First(&user).Error; err != nil {
|
||||
if err := a.DB.Where("id = ?", account.UserID).First(&user).Error; err != nil {
|
||||
HandleError(w, errors.Wrap(err, "finding user").Error(), nil, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
respondWithSession(w, user.ID, http.StatusOK)
|
||||
respondWithSession(a.DB, w, user.ID, http.StatusOK)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,8 +31,8 @@ import (
|
|||
)
|
||||
|
||||
func TestGetMe(t *testing.T) {
|
||||
testutils.InitTestDB()
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
|
@ -41,7 +41,7 @@ func TestGetMe(t *testing.T) {
|
|||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData()
|
||||
testutils.SetupAccountData(u, "alice@example.com", "somepassword")
|
||||
testutils.SetupAccountData( u, "alice@example.com", "somepassword")
|
||||
|
||||
dat := `{"email": "alice@example.com"}`
|
||||
req := testutils.MakeReq(server, "POST", "/reset-token", dat)
|
||||
|
|
@ -53,23 +53,24 @@ func TestGetMe(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusOK, "Status code mismtach")
|
||||
|
||||
var user database.User
|
||||
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", u.ID).First(&user), "finding user")
|
||||
assert.Equal(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch")
|
||||
}
|
||||
|
||||
func TestCreateResetToken(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData()
|
||||
testutils.SetupAccountData(u, "alice@example.com", "somepassword")
|
||||
testutils.SetupAccountData( u, "alice@example.com", "somepassword")
|
||||
|
||||
dat := `{"email": "alice@example.com"}`
|
||||
req := testutils.MakeReq(server, "POST", "/reset-token", dat)
|
||||
|
|
@ -81,10 +82,10 @@ func TestCreateResetToken(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusOK, "Status code mismtach")
|
||||
|
||||
var tokenCount int
|
||||
testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting tokens")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting tokens")
|
||||
|
||||
var resetToken database.Token
|
||||
testutils.MustExec(t, db.Where("user_id = ? AND type = ?", u.ID, database.TokenTypeResetPassword).First(&resetToken), "finding reset token")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", u.ID, database.TokenTypeResetPassword).First(&resetToken), "finding reset token")
|
||||
|
||||
assert.Equal(t, tokenCount, 1, "reset_token count mismatch")
|
||||
assert.NotEqual(t, resetToken.Value, nil, "reset_token value mismatch")
|
||||
|
|
@ -92,17 +93,18 @@ func TestCreateResetToken(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("nonexistent email", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData()
|
||||
testutils.SetupAccountData(u, "alice@example.com", "somepassword")
|
||||
testutils.SetupAccountData( u, "alice@example.com", "somepassword")
|
||||
|
||||
dat := `{"email": "bob@example.com"}`
|
||||
req := testutils.MakeReq(server, "POST", "/reset-token", dat)
|
||||
|
|
@ -114,36 +116,37 @@ func TestCreateResetToken(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusOK, "Status code mismtach")
|
||||
|
||||
var tokenCount int
|
||||
testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting tokens")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting tokens")
|
||||
assert.Equal(t, tokenCount, 0, "reset_token count mismatch")
|
||||
})
|
||||
}
|
||||
|
||||
func TestResetPassword(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData()
|
||||
a := testutils.SetupAccountData(u, "alice@example.com", "oldpassword")
|
||||
a := testutils.SetupAccountData( u, "alice@example.com", "oldpassword")
|
||||
tok := database.Token{
|
||||
UserID: u.ID,
|
||||
Value: "MivFxYiSMMA4An9dP24DNQ==",
|
||||
Type: database.TokenTypeResetPassword,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&tok), "preparing token")
|
||||
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
|
||||
otherTok := database.Token{
|
||||
UserID: u.ID,
|
||||
Value: "somerandomvalue",
|
||||
Type: database.TokenTypeEmailVerification,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&otherTok), "preparing another token")
|
||||
testutils.MustExec(t, testutils.DB.Save(&otherTok), "preparing another token")
|
||||
|
||||
dat := `{"token": "MivFxYiSMMA4An9dP24DNQ==", "password": "newpassword"}`
|
||||
req := testutils.MakeReq(server, "PATCH", "/reset-password", dat)
|
||||
|
|
@ -156,9 +159,9 @@ func TestResetPassword(t *testing.T) {
|
|||
|
||||
var resetToken, verificationToken database.Token
|
||||
var account database.Account
|
||||
testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
|
||||
testutils.MustExec(t, db.Where("value = ?", "somerandomvalue").First(&verificationToken), "finding reset token")
|
||||
testutils.MustExec(t, db.Where("id = ?", a.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
|
||||
testutils.MustExec(t, testutils.DB.Where("value = ?", "somerandomvalue").First(&verificationToken), "finding reset token")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "finding account")
|
||||
|
||||
assert.NotEqual(t, resetToken.UsedAt, nil, "reset_token UsedAt mismatch")
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte("newpassword"))
|
||||
|
|
@ -167,23 +170,24 @@ func TestResetPassword(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("nonexistent token", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData()
|
||||
a := testutils.SetupAccountData(u, "alice@example.com", "somepassword")
|
||||
a := testutils.SetupAccountData( u, "alice@example.com", "somepassword")
|
||||
tok := database.Token{
|
||||
UserID: u.ID,
|
||||
Value: "MivFxYiSMMA4An9dP24DNQ==",
|
||||
Type: database.TokenTypeResetPassword,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&tok), "preparing token")
|
||||
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
|
||||
|
||||
dat := `{"token": "-ApMnyvpg59uOU5b-Kf5uQ==", "password": "oldpassword"}`
|
||||
req := testutils.MakeReq(server, "PATCH", "/reset-password", dat)
|
||||
|
|
@ -196,8 +200,8 @@ func TestResetPassword(t *testing.T) {
|
|||
|
||||
var resetToken database.Token
|
||||
var account database.Account
|
||||
testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
|
||||
testutils.MustExec(t, db.Where("id = ?", a.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "finding account")
|
||||
|
||||
assert.Equal(t, a.Password, account.Password, "password should not have been updated")
|
||||
assert.Equal(t, a.Password, account.Password, "password should not have been updated")
|
||||
|
|
@ -205,24 +209,25 @@ func TestResetPassword(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("expired token", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData()
|
||||
a := testutils.SetupAccountData(u, "alice@example.com", "somepassword")
|
||||
a := testutils.SetupAccountData( u, "alice@example.com", "somepassword")
|
||||
tok := database.Token{
|
||||
UserID: u.ID,
|
||||
Value: "MivFxYiSMMA4An9dP24DNQ==",
|
||||
Type: database.TokenTypeResetPassword,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&tok), "preparing token")
|
||||
testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
|
||||
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
|
||||
testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
|
||||
|
||||
dat := `{"token": "MivFxYiSMMA4An9dP24DNQ==", "password": "oldpassword"}`
|
||||
req := testutils.MakeReq(server, "PATCH", "/reset-password", dat)
|
||||
|
|
@ -235,24 +240,25 @@ func TestResetPassword(t *testing.T) {
|
|||
|
||||
var resetToken database.Token
|
||||
var account database.Account
|
||||
testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
|
||||
testutils.MustExec(t, db.Where("id = ?", a.ID).First(&account), "failed to find account")
|
||||
testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "failed to find account")
|
||||
assert.Equal(t, a.Password, account.Password, "password should not have been updated")
|
||||
assert.Equal(t, resetToken.UsedAt, (*time.Time)(nil), "used_at should be nil")
|
||||
})
|
||||
|
||||
t.Run("used token", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData()
|
||||
a := testutils.SetupAccountData(u, "alice@example.com", "somepassword")
|
||||
a := testutils.SetupAccountData( u, "alice@example.com", "somepassword")
|
||||
|
||||
usedAt := time.Now().Add(time.Hour * -11).UTC()
|
||||
tok := database.Token{
|
||||
|
|
@ -261,8 +267,8 @@ func TestResetPassword(t *testing.T) {
|
|||
Type: database.TokenTypeResetPassword,
|
||||
UsedAt: &usedAt,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&tok), "preparing token")
|
||||
testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
|
||||
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
|
||||
testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
|
||||
|
||||
dat := `{"token": "MivFxYiSMMA4An9dP24DNQ==", "password": "oldpassword"}`
|
||||
req := testutils.MakeReq(server, "PATCH", "/reset-password", dat)
|
||||
|
|
@ -275,8 +281,8 @@ func TestResetPassword(t *testing.T) {
|
|||
|
||||
var resetToken database.Token
|
||||
var account database.Account
|
||||
testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
|
||||
testutils.MustExec(t, db.Where("id = ?", a.ID).First(&account), "failed to find account")
|
||||
testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "failed to find account")
|
||||
assert.Equal(t, a.Password, account.Password, "password should not have been updated")
|
||||
|
||||
if resetToken.UsedAt.Year() != usedAt.Year() ||
|
||||
|
|
@ -290,24 +296,25 @@ func TestResetPassword(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("using wrong type token: email_verification", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
u := testutils.SetupUserData()
|
||||
a := testutils.SetupAccountData(u, "alice@example.com", "somepassword")
|
||||
a := testutils.SetupAccountData( u, "alice@example.com", "somepassword")
|
||||
tok := database.Token{
|
||||
UserID: u.ID,
|
||||
Value: "MivFxYiSMMA4An9dP24DNQ==",
|
||||
Type: database.TokenTypeEmailVerification,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&tok), "Failed to prepare reset_token")
|
||||
testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
|
||||
testutils.MustExec(t, testutils.DB.Save(&tok), "Failed to prepare reset_token")
|
||||
testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
|
||||
|
||||
dat := `{"token": "MivFxYiSMMA4An9dP24DNQ==", "password": "oldpassword"}`
|
||||
req := testutils.MakeReq(server, "PATCH", "/reset-password", dat)
|
||||
|
|
@ -320,8 +327,8 @@ func TestResetPassword(t *testing.T) {
|
|||
|
||||
var resetToken database.Token
|
||||
var account database.Account
|
||||
testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
|
||||
testutils.MustExec(t, db.Where("id = ?", a.ID).First(&account), "failed to find account")
|
||||
testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "failed to find account")
|
||||
|
||||
assert.Equal(t, a.Password, account.Password, "password should not have been updated")
|
||||
assert.Equal(t, resetToken.UsedAt, (*time.Time)(nil), "used_at should be nil")
|
||||
|
|
|
|||
|
|
@ -23,11 +23,11 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/dnote/dnote/pkg/server/crypt"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/helpers"
|
||||
"github.com/dnote/dnote/pkg/server/log"
|
||||
"github.com/dnote/dnote/pkg/server/operations"
|
||||
"github.com/dnote/dnote/pkg/server/presenters"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/log"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
|
@ -39,15 +39,13 @@ func (a *App) classicMigrate(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
db := database.DBConn
|
||||
|
||||
var account database.Account
|
||||
if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
|
||||
if err := a.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
|
||||
HandleError(w, "finding account", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := db.Model(&account).
|
||||
if err := a.DB.Model(&account).
|
||||
Update(map[string]interface{}{
|
||||
"salt": "",
|
||||
"auth_key_hash": "",
|
||||
|
|
@ -66,8 +64,6 @@ type PresigninResponse struct {
|
|||
}
|
||||
|
||||
func (a *App) classicPresignin(w http.ResponseWriter, r *http.Request) {
|
||||
db := database.DBConn
|
||||
|
||||
q := r.URL.Query()
|
||||
email := q.Get("email")
|
||||
if email == "" {
|
||||
|
|
@ -76,7 +72,7 @@ func (a *App) classicPresignin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var account database.Account
|
||||
conn := db.Where("email = ?", email).First(&account)
|
||||
conn := a.DB.Where("email = ?", email).First(&account)
|
||||
if !conn.RecordNotFound() && conn.Error != nil {
|
||||
HandleError(w, "getting user", conn.Error, http.StatusInternalServerError)
|
||||
return
|
||||
|
|
@ -106,8 +102,6 @@ type classicSigninPayload struct {
|
|||
}
|
||||
|
||||
func (a *App) classicSignin(w http.ResponseWriter, r *http.Request) {
|
||||
db := database.DBConn
|
||||
|
||||
var params classicSigninPayload
|
||||
if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil {
|
||||
HandleError(w, "decoding payload", err, http.StatusInternalServerError)
|
||||
|
|
@ -120,7 +114,7 @@ func (a *App) classicSignin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var account database.Account
|
||||
conn := db.Where("email = ?", params.Email).First(&account)
|
||||
conn := a.DB.Where("email = ?", params.Email).First(&account)
|
||||
if conn.RecordNotFound() {
|
||||
http.Error(w, ErrLoginFailure.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
|
|
@ -138,7 +132,7 @@ func (a *App) classicSignin(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
session, err := operations.CreateSession(db, account.UserID)
|
||||
session, err := operations.CreateSession(a.DB, account.UserID)
|
||||
if err != nil {
|
||||
HandleError(w, "creating session", nil, http.StatusBadRequest)
|
||||
return
|
||||
|
|
@ -169,10 +163,8 @@ func (a *App) classicGetMe(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
db := database.DBConn
|
||||
|
||||
var account database.Account
|
||||
if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
|
||||
if err := a.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
|
||||
HandleError(w, "finding account", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -229,8 +221,6 @@ func (a *App) classicSetPassword(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
db := database.DBConn
|
||||
|
||||
var params classicSetPasswordPayload
|
||||
if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil {
|
||||
HandleError(w, "decoding payload", err, http.StatusInternalServerError)
|
||||
|
|
@ -238,7 +228,7 @@ func (a *App) classicSetPassword(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var account database.Account
|
||||
if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
|
||||
if err := a.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
|
||||
HandleError(w, "getting user", nil, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -249,7 +239,7 @@ func (a *App) classicSetPassword(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if err := db.Model(&account).Update("password", string(hashedNewPassword)).Error; err != nil {
|
||||
if err := a.DB.Model(&account).Update("password", string(hashedNewPassword)).Error; err != nil {
|
||||
http.Error(w, errors.Wrap(err, "updating password").Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -265,8 +255,7 @@ func (a *App) classicGetNotes(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var notes []database.Note
|
||||
db := database.DBConn
|
||||
if err := db.Where("user_id = ? AND encrypted = true", user.ID).Find(¬es).Error; err != nil {
|
||||
if err := a.DB.Where("user_id = ? AND encrypted = true", user.ID).Find(¬es).Error; err != nil {
|
||||
HandleError(w, "finding notes", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,26 +22,16 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/dnote/dnote/pkg/assert"
|
||||
"github.com/dnote/dnote/pkg/clock"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/mailer"
|
||||
"github.com/dnote/dnote/pkg/server/testutils"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func init() {
|
||||
testutils.InitTestDB()
|
||||
|
||||
templatePath := os.Getenv("DNOTE_TEST_EMAIL_TEMPLATE_DIR")
|
||||
mailer.InitTemplates(&templatePath)
|
||||
}
|
||||
|
||||
func TestClassicPresignin(t *testing.T) {
|
||||
db := database.DBConn
|
||||
defer testutils.ClearData()
|
||||
|
||||
alice := database.Account{
|
||||
|
|
@ -52,8 +42,8 @@ func TestClassicPresignin(t *testing.T) {
|
|||
Email: database.ToNullString("bob@example.com"),
|
||||
ClientKDFIteration: 200000,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&alice), "saving alice")
|
||||
testutils.MustExec(t, db.Save(&bob), "saving bob")
|
||||
testutils.MustExec(t, testutils.DB.Save(&alice), "saving alice")
|
||||
testutils.MustExec(t, testutils.DB.Save(&bob), "saving bob")
|
||||
|
||||
testCases := []struct {
|
||||
email string
|
||||
|
|
@ -121,12 +111,11 @@ func TestClassicPresignin_MissingParams(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestClassicSignin(t *testing.T) {
|
||||
db := database.DBConn
|
||||
defer testutils.ClearData()
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
alice := testutils.SetupClassicAccountData(user, "alice@example.com")
|
||||
testutils.MustExec(t, db.Save(&alice), "saving alice")
|
||||
testutils.MustExec(t, testutils.DB.Save(&alice), "saving alice")
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
|
@ -145,8 +134,8 @@ func TestClassicSignin(t *testing.T) {
|
|||
|
||||
var sessionCount int
|
||||
var session database.Session
|
||||
testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
|
||||
testutils.MustExec(t, db.First(&session), "getting session")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
|
||||
testutils.MustExec(t, testutils.DB.First(&session), "getting session")
|
||||
|
||||
var got SessionResponse
|
||||
if err := json.NewDecoder(res.Body).Decode(&got); err != nil {
|
||||
|
|
@ -165,7 +154,7 @@ func TestClassicSignin(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestClassicSignin_Failure(t *testing.T) {
|
||||
db := database.DBConn
|
||||
|
||||
defer testutils.ClearData()
|
||||
|
||||
//password: correctbattery
|
||||
|
|
@ -183,8 +172,8 @@ func TestClassicSignin_Failure(t *testing.T) {
|
|||
// plain authKey: DN4d/teaq1I2bVYZ7QWaah4Fu7q2y2N4yJNZk76hFHw=
|
||||
AuthKeyHash: "fGOMHHAw9G7CH4Gv2EM1ZcZZklC1a55fS3QJ0qQVp4k=",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&alice), "saving alice")
|
||||
testutils.MustExec(t, db.Save(&bob), "saving bob")
|
||||
testutils.MustExec(t, testutils.DB.Save(&alice), "saving alice")
|
||||
testutils.MustExec(t, testutils.DB.Save(&bob), "saving bob")
|
||||
|
||||
testCases := []struct {
|
||||
email string
|
||||
|
|
|
|||
|
|
@ -25,13 +25,13 @@ import (
|
|||
"github.com/dnote/dnote/pkg/assert"
|
||||
"github.com/dnote/dnote/pkg/clock"
|
||||
"github.com/dnote/dnote/pkg/server/testutils"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
func TestCheckHealth(t *testing.T) {
|
||||
defer testutils.ClearData()
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
DB: &gorm.DB{},
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
|
|||
20
pkg/server/handlers/main_test.go
Normal file
20
pkg/server/handlers/main_test.go
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/dnote/dnote/pkg/server/mailer"
|
||||
"github.com/dnote/dnote/pkg/server/testutils"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutils.InitTestDB()
|
||||
templatePath := os.Getenv("DNOTE_TEST_EMAIL_TEMPLATE_DIR")
|
||||
mailer.InitTemplates(&templatePath)
|
||||
|
||||
code := m.Run()
|
||||
testutils.ClearData()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
|
@ -86,9 +86,7 @@ func parseSearchQuery(q url.Values) string {
|
|||
return escapeSearchQuery(searchStr)
|
||||
}
|
||||
|
||||
func getNoteBaseQuery(noteUUID string, search string) *gorm.DB {
|
||||
db := database.DBConn
|
||||
|
||||
func getNoteBaseQuery(db *gorm.DB, noteUUID string, search string) *gorm.DB {
|
||||
var conn *gorm.DB
|
||||
if search != "" {
|
||||
conn = selectFTSFields(db, search, &ftsParams{HighlightAll: true})
|
||||
|
|
@ -102,7 +100,7 @@ func getNoteBaseQuery(noteUUID string, search string) *gorm.DB {
|
|||
}
|
||||
|
||||
func (a *App) getNote(w http.ResponseWriter, r *http.Request) {
|
||||
user, _, err := AuthWithSession(r, nil)
|
||||
user, _, err := AuthWithSession(a.DB, r, nil)
|
||||
if err != nil {
|
||||
HandleError(w, "authenticating", err, http.StatusInternalServerError)
|
||||
return
|
||||
|
|
@ -111,7 +109,7 @@ func (a *App) getNote(w http.ResponseWriter, r *http.Request) {
|
|||
vars := mux.Vars(r)
|
||||
noteUUID := vars["noteUUID"]
|
||||
|
||||
note, ok, err := operations.GetNote(noteUUID, user)
|
||||
note, ok, err := operations.GetNote(a.DB, noteUUID, user)
|
||||
if !ok {
|
||||
RespondNotFound(w)
|
||||
return
|
||||
|
|
@ -145,17 +143,17 @@ func (a *App) getNotes(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
query := r.URL.Query()
|
||||
|
||||
respondGetNotes(user.ID, query, w)
|
||||
respondGetNotes(a.DB, user.ID, query, w)
|
||||
}
|
||||
|
||||
func respondGetNotes(userID int, query url.Values, w http.ResponseWriter) {
|
||||
func respondGetNotes(db *gorm.DB, userID int, query url.Values, w http.ResponseWriter) {
|
||||
q, err := parseGetNotesQuery(query)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
conn := getNotesBaseQuery(userID, q)
|
||||
conn := getNotesBaseQuery(db, userID, q)
|
||||
|
||||
var total int
|
||||
if err := conn.Model(database.Note{}).Count(&total).Error; err != nil {
|
||||
|
|
@ -274,9 +272,7 @@ func getDateBounds(year, month int) (int64, int64) {
|
|||
return lower, upper
|
||||
}
|
||||
|
||||
func getNotesBaseQuery(userID int, q getNotesQuery) *gorm.DB {
|
||||
db := database.DBConn
|
||||
|
||||
func getNotesBaseQuery(db *gorm.DB, userID int, q getNotesQuery) *gorm.DB {
|
||||
conn := db.Where(
|
||||
"notes.user_id = ? AND notes.deleted = ? AND notes.encrypted = ?",
|
||||
userID, false, q.Encrypted,
|
||||
|
|
@ -317,8 +313,7 @@ func (a *App) legacyGetNotes(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var notes []database.Note
|
||||
db := database.DBConn
|
||||
if err := db.Where("user_id = ? AND encrypted = true", user.ID).Find(¬es).Error; err != nil {
|
||||
if err := a.DB.Where("user_id = ? AND encrypted = true", user.ID).Find(¬es).Error; err != nil {
|
||||
HandleError(w, "finding notes", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,16 +28,12 @@ import (
|
|||
|
||||
"github.com/dnote/dnote/pkg/assert"
|
||||
"github.com/dnote/dnote/pkg/clock"
|
||||
"github.com/dnote/dnote/pkg/server/presenters"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/presenters"
|
||||
"github.com/dnote/dnote/pkg/server/testutils"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func init() {
|
||||
testutils.InitTestDB()
|
||||
}
|
||||
|
||||
func getExpectedNotePayload(n database.Note, b database.Book, u database.User) presenters.Note {
|
||||
return presenters.Note{
|
||||
UUID: n.UUID,
|
||||
|
|
@ -59,11 +55,12 @@ func getExpectedNotePayload(n database.Note, b database.Book, u database.User) p
|
|||
}
|
||||
|
||||
func TestGetNotes(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -75,17 +72,17 @@ func TestGetNotes(t *testing.T) {
|
|||
UserID: user.ID,
|
||||
Label: "js",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing b1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
|
||||
b2 := database.Book{
|
||||
UserID: user.ID,
|
||||
Label: "css",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b2), "preparing b2")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
|
||||
b3 := database.Book{
|
||||
UserID: anotherUser.ID,
|
||||
Label: "css",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b3), "preparing b3")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
|
||||
|
||||
n1 := database.Note{
|
||||
UserID: user.ID,
|
||||
|
|
@ -95,7 +92,7 @@ func TestGetNotes(t *testing.T) {
|
|||
Deleted: false,
|
||||
AddedOn: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n1), "preparing n1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n1), "preparing n1")
|
||||
n2 := database.Note{
|
||||
UserID: user.ID,
|
||||
BookUUID: b1.UUID,
|
||||
|
|
@ -104,7 +101,7 @@ func TestGetNotes(t *testing.T) {
|
|||
Deleted: false,
|
||||
AddedOn: time.Date(2018, time.August, 11, 22, 0, 0, 0, time.UTC).UnixNano(),
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n2), "preparing n2")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n2), "preparing n2")
|
||||
n3 := database.Note{
|
||||
UserID: user.ID,
|
||||
BookUUID: b1.UUID,
|
||||
|
|
@ -113,7 +110,7 @@ func TestGetNotes(t *testing.T) {
|
|||
Deleted: false,
|
||||
AddedOn: time.Date(2017, time.January, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n3), "preparing n3")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n3), "preparing n3")
|
||||
n4 := database.Note{
|
||||
UserID: user.ID,
|
||||
BookUUID: b2.UUID,
|
||||
|
|
@ -122,7 +119,7 @@ func TestGetNotes(t *testing.T) {
|
|||
Deleted: false,
|
||||
AddedOn: time.Date(2018, time.September, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n4), "preparing n4")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n4), "preparing n4")
|
||||
n5 := database.Note{
|
||||
UserID: anotherUser.ID,
|
||||
BookUUID: b3.UUID,
|
||||
|
|
@ -131,7 +128,7 @@ func TestGetNotes(t *testing.T) {
|
|||
Deleted: false,
|
||||
AddedOn: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n5), "preparing n5")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n5), "preparing n5")
|
||||
n6 := database.Note{
|
||||
UserID: user.ID,
|
||||
BookUUID: b1.UUID,
|
||||
|
|
@ -140,7 +137,7 @@ func TestGetNotes(t *testing.T) {
|
|||
Deleted: true,
|
||||
AddedOn: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n6), "preparing n6")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n6), "preparing n6")
|
||||
|
||||
// Execute
|
||||
req := testutils.MakeReq(server, "GET", "/notes?year=2018&month=8", "")
|
||||
|
|
@ -155,8 +152,8 @@ func TestGetNotes(t *testing.T) {
|
|||
}
|
||||
|
||||
var n2Record, n1Record database.Note
|
||||
testutils.MustExec(t, db.Where("uuid = ?", n2.UUID).First(&n2Record), "finding n2Record")
|
||||
testutils.MustExec(t, db.Where("uuid = ?", n1.UUID).First(&n1Record), "finding n1Record")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", n2.UUID).First(&n2Record), "finding n2Record")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", n1.UUID).First(&n1Record), "finding n1Record")
|
||||
|
||||
expected := GetNotesResponse{
|
||||
Notes: []presenters.Note{
|
||||
|
|
@ -170,11 +167,12 @@ func TestGetNotes(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestGetNote(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -186,7 +184,7 @@ func TestGetNote(t *testing.T) {
|
|||
UserID: user.ID,
|
||||
Label: "js",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing b1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
|
||||
|
||||
privateNote := database.Note{
|
||||
UserID: user.ID,
|
||||
|
|
@ -194,20 +192,20 @@ func TestGetNote(t *testing.T) {
|
|||
Body: "privateNote content",
|
||||
Public: false,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&privateNote), "preparing privateNote")
|
||||
testutils.MustExec(t, testutils.DB.Save(&privateNote), "preparing privateNote")
|
||||
publicNote := database.Note{
|
||||
UserID: user.ID,
|
||||
BookUUID: b1.UUID,
|
||||
Body: "publicNote content",
|
||||
Public: true,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&publicNote), "preparing publicNote")
|
||||
testutils.MustExec(t, testutils.DB.Save(&publicNote), "preparing publicNote")
|
||||
deletedNote := database.Note{
|
||||
UserID: user.ID,
|
||||
BookUUID: b1.UUID,
|
||||
Deleted: true,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&deletedNote), "preparing publicNote")
|
||||
testutils.MustExec(t, testutils.DB.Save(&deletedNote), "preparing publicNote")
|
||||
|
||||
t.Run("owner accessing private note", func(t *testing.T) {
|
||||
// Execute
|
||||
|
|
@ -224,7 +222,7 @@ func TestGetNote(t *testing.T) {
|
|||
}
|
||||
|
||||
var n1Record database.Note
|
||||
testutils.MustExec(t, db.Where("uuid = ?", privateNote.UUID).First(&n1Record), "finding n1Record")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", privateNote.UUID).First(&n1Record), "finding n1Record")
|
||||
|
||||
expected := getExpectedNotePayload(n1Record, b1, user)
|
||||
assert.DeepEqual(t, payload, expected, "payload mismatch")
|
||||
|
|
@ -245,7 +243,7 @@ func TestGetNote(t *testing.T) {
|
|||
}
|
||||
|
||||
var n2Record database.Note
|
||||
testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
|
||||
|
||||
expected := getExpectedNotePayload(n2Record, b1, user)
|
||||
assert.DeepEqual(t, payload, expected, "payload mismatch")
|
||||
|
|
@ -266,7 +264,7 @@ func TestGetNote(t *testing.T) {
|
|||
}
|
||||
|
||||
var n2Record database.Note
|
||||
testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
|
||||
|
||||
expected := getExpectedNotePayload(n2Record, b1, user)
|
||||
assert.DeepEqual(t, payload, expected, "payload mismatch")
|
||||
|
|
@ -304,7 +302,7 @@ func TestGetNote(t *testing.T) {
|
|||
}
|
||||
|
||||
var n2Record database.Note
|
||||
testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
|
||||
|
||||
expected := getExpectedNotePayload(n2Record, b1, user)
|
||||
assert.DeepEqual(t, payload, expected, "payload mismatch")
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/helpers"
|
||||
"github.com/dnote/dnote/pkg/server/presenters"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
|
@ -45,9 +45,8 @@ func (a *App) getRepetitionRule(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
db := database.DBConn
|
||||
var repetitionRule database.RepetitionRule
|
||||
if err := db.Where("user_id = ? AND uuid = ?", user.ID, repetitionRuleUUID).Preload("Books").Find(&repetitionRule).Error; err != nil {
|
||||
if err := a.DB.Where("user_id = ? AND uuid = ?", user.ID, repetitionRuleUUID).Preload("Books").Find(&repetitionRule).Error; err != nil {
|
||||
HandleError(w, "getting repetition rules", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -63,9 +62,8 @@ func (a *App) getRepetitionRules(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
db := database.DBConn
|
||||
var repetitionRules []database.RepetitionRule
|
||||
if err := db.Where("user_id = ?", user.ID).Preload("Books").Order("last_active DESC").Find(&repetitionRules).Error; err != nil {
|
||||
if err := a.DB.Where("user_id = ?", user.ID).Preload("Books").Order("last_active DESC").Find(&repetitionRules).Error; err != nil {
|
||||
HandleError(w, "getting repetition rules", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -288,9 +286,8 @@ func (a *App) createRepetitionRule(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
db := database.DBConn
|
||||
var books []database.Book
|
||||
if err := db.Where("user_id = ? AND uuid IN (?)", user.ID, params.GetBookUUIDs()).Find(&books).Error; err != nil {
|
||||
if err := a.DB.Where("user_id = ? AND uuid IN (?)", user.ID, params.GetBookUUIDs()).Find(&books).Error; err != nil {
|
||||
HandleError(w, "finding books", nil, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -313,7 +310,7 @@ func (a *App) createRepetitionRule(w http.ResponseWriter, r *http.Request) {
|
|||
NoteCount: params.GetNoteCount(),
|
||||
Enabled: params.GetEnabled(),
|
||||
}
|
||||
if err := db.Create(&record).Error; err != nil {
|
||||
if err := a.DB.Create(&record).Error; err != nil {
|
||||
HandleError(w, "creating a repetition rule", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -346,10 +343,8 @@ func (a *App) deleteRepetitionRule(w http.ResponseWriter, r *http.Request) {
|
|||
vars := mux.Vars(r)
|
||||
repetitionRuleUUID := vars["repetitionRuleUUID"]
|
||||
|
||||
db := database.DBConn
|
||||
|
||||
var rule database.RepetitionRule
|
||||
conn := db.Where("uuid = ? AND user_id = ?", repetitionRuleUUID, user.ID).First(&rule)
|
||||
conn := a.DB.Where("uuid = ? AND user_id = ?", repetitionRuleUUID, user.ID).First(&rule)
|
||||
|
||||
if conn.RecordNotFound() {
|
||||
http.Error(w, "Not found", http.StatusNotFound)
|
||||
|
|
@ -359,7 +354,7 @@ func (a *App) deleteRepetitionRule(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if err := db.Exec("DELETE from repetition_rules WHERE uuid = ?", rule.UUID).Error; err != nil {
|
||||
if err := a.DB.Exec("DELETE from repetition_rules WHERE uuid = ?", rule.UUID).Error; err != nil {
|
||||
HandleError(w, "deleting the repetition rule", err, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
|
|
@ -382,8 +377,7 @@ func (a *App) updateRepetitionRule(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
db := database.DBConn
|
||||
tx := db.Begin()
|
||||
tx := a.DB.Begin()
|
||||
|
||||
var repetitionRule database.RepetitionRule
|
||||
if err := tx.Where("user_id = ? AND uuid = ?", user.ID, repetitionRuleUUID).Preload("Books").First(&repetitionRule).Error; err != nil {
|
||||
|
|
|
|||
|
|
@ -27,22 +27,19 @@ import (
|
|||
|
||||
"github.com/dnote/dnote/pkg/assert"
|
||||
"github.com/dnote/dnote/pkg/clock"
|
||||
"github.com/dnote/dnote/pkg/server/presenters"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/presenters"
|
||||
"github.com/dnote/dnote/pkg/server/testutils"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func init() {
|
||||
testutils.InitTestDB()
|
||||
}
|
||||
|
||||
func TestGetRepetitionRule(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -53,7 +50,7 @@ func TestGetRepetitionRule(t *testing.T) {
|
|||
USN: 11,
|
||||
Label: "js",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing book1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing book1")
|
||||
|
||||
r1 := database.RepetitionRule{
|
||||
Title: "Rule 1",
|
||||
|
|
@ -66,7 +63,7 @@ func TestGetRepetitionRule(t *testing.T) {
|
|||
Books: []database.Book{b1},
|
||||
NoteCount: 5,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&r1), "preparing rule1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
|
||||
|
||||
// Execute
|
||||
req := testutils.MakeReq(server, "GET", fmt.Sprintf("/repetition_rules/%s", r1.UUID), "")
|
||||
|
|
@ -81,9 +78,9 @@ func TestGetRepetitionRule(t *testing.T) {
|
|||
}
|
||||
|
||||
var r1Record database.RepetitionRule
|
||||
testutils.MustExec(t, db.Where("uuid = ?", r1.UUID).First(&r1Record), "finding r1Record")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", r1.UUID).First(&r1Record), "finding r1Record")
|
||||
var b1Record database.Book
|
||||
testutils.MustExec(t, db.Where("uuid = ?", b1.UUID).First(&b1Record), "finding b1Record")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", b1.UUID).First(&b1Record), "finding b1Record")
|
||||
|
||||
expected := presenters.RepetitionRule{
|
||||
UUID: r1Record.UUID,
|
||||
|
|
@ -112,11 +109,12 @@ func TestGetRepetitionRule(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestGetRepetitionRules(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -127,7 +125,7 @@ func TestGetRepetitionRules(t *testing.T) {
|
|||
USN: 11,
|
||||
Label: "js",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing book1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing book1")
|
||||
|
||||
r1 := database.RepetitionRule{
|
||||
Title: "Rule 1",
|
||||
|
|
@ -140,7 +138,7 @@ func TestGetRepetitionRules(t *testing.T) {
|
|||
Books: []database.Book{b1},
|
||||
NoteCount: 5,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&r1), "preparing rule1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
|
||||
|
||||
r2 := database.RepetitionRule{
|
||||
Title: "Rule 2",
|
||||
|
|
@ -153,7 +151,7 @@ func TestGetRepetitionRules(t *testing.T) {
|
|||
Books: []database.Book{},
|
||||
NoteCount: 5,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&r2), "preparing rule2")
|
||||
testutils.MustExec(t, testutils.DB.Save(&r2), "preparing rule2")
|
||||
|
||||
// Execute
|
||||
req := testutils.MakeReq(server, "GET", "/repetition_rules", "")
|
||||
|
|
@ -168,10 +166,10 @@ func TestGetRepetitionRules(t *testing.T) {
|
|||
}
|
||||
|
||||
var r1Record, r2Record database.RepetitionRule
|
||||
testutils.MustExec(t, db.Where("uuid = ?", r1.UUID).First(&r1Record), "finding r1Record")
|
||||
testutils.MustExec(t, db.Where("uuid = ?", r2.UUID).First(&r2Record), "finding r2Record")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", r1.UUID).First(&r1Record), "finding r1Record")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", r2.UUID).First(&r2Record), "finding r2Record")
|
||||
var b1Record database.Book
|
||||
testutils.MustExec(t, db.Where("uuid = ?", b1.UUID).First(&b1Record), "finding b1Record")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", b1.UUID).First(&b1Record), "finding b1Record")
|
||||
|
||||
expected := []presenters.RepetitionRule{
|
||||
{
|
||||
|
|
@ -217,8 +215,8 @@ func TestGetRepetitionRules(t *testing.T) {
|
|||
|
||||
func TestCreateRepetitionRules(t *testing.T) {
|
||||
t.Run("all books", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
c := clock.NewMock()
|
||||
|
|
@ -226,6 +224,7 @@ func TestCreateRepetitionRules(t *testing.T) {
|
|||
c.SetNow(t0)
|
||||
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: c,
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -250,11 +249,11 @@ func TestCreateRepetitionRules(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusCreated, "")
|
||||
|
||||
var ruleCount int
|
||||
testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
|
||||
assert.Equalf(t, ruleCount, 1, "reperition rule count mismatch")
|
||||
|
||||
var rule database.RepetitionRule
|
||||
testutils.MustExec(t, db.Preload("Books").First(&rule), "finding b1Record")
|
||||
testutils.MustExec(t, testutils.DB.Preload("Books").First(&rule), "finding b1Record")
|
||||
|
||||
assert.NotEqual(t, rule.UUID, "", "rule UUID mismatch")
|
||||
assert.Equal(t, rule.Title, "Rule 1", "rule Title mismatch")
|
||||
|
|
@ -275,8 +274,8 @@ func TestCreateRepetitionRules(t *testing.T) {
|
|||
}
|
||||
for _, tc := range bookDomainTestCases {
|
||||
t.Run(tc, func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
c := clock.NewMock()
|
||||
|
|
@ -284,6 +283,7 @@ func TestCreateRepetitionRules(t *testing.T) {
|
|||
c.SetNow(t0)
|
||||
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: c,
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -294,7 +294,7 @@ func TestCreateRepetitionRules(t *testing.T) {
|
|||
UserID: user.ID,
|
||||
Label: "css",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing b1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
|
||||
|
||||
// Execute
|
||||
dat := fmt.Sprintf(`{
|
||||
|
|
@ -314,14 +314,14 @@ func TestCreateRepetitionRules(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusCreated, "")
|
||||
|
||||
var ruleCount int
|
||||
testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
|
||||
assert.Equalf(t, ruleCount, 1, "reperition rule count mismatch")
|
||||
|
||||
var rule database.RepetitionRule
|
||||
testutils.MustExec(t, db.Preload("Books").First(&rule), "finding b1Record")
|
||||
testutils.MustExec(t, testutils.DB.Preload("Books").First(&rule), "finding b1Record")
|
||||
|
||||
var b1Record database.Book
|
||||
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1Record")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1Record")
|
||||
|
||||
assert.NotEqual(t, rule.UUID, "", "rule UUID mismatch")
|
||||
assert.Equal(t, rule.Title, "Rule 1", "rule Title mismatch")
|
||||
|
|
@ -339,14 +339,15 @@ func TestCreateRepetitionRules(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestUpdateRepetitionRules(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
c := clock.NewMock()
|
||||
t0 := time.Date(2009, time.November, 1, 2, 3, 4, 5, time.UTC)
|
||||
c.SetNow(t0)
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: c,
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -367,13 +368,13 @@ func TestUpdateRepetitionRules(t *testing.T) {
|
|||
Books: []database.Book{},
|
||||
NoteCount: 20,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&r1), "preparing r1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing r1")
|
||||
b1 := database.Book{
|
||||
UserID: user.ID,
|
||||
USN: 11,
|
||||
Label: "js",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing book1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing book1")
|
||||
|
||||
dat := fmt.Sprintf(`{
|
||||
"title": "Rule 1 - edited",
|
||||
|
|
@ -393,14 +394,14 @@ func TestUpdateRepetitionRules(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusOK, "")
|
||||
|
||||
var totalRuleCount int
|
||||
testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&totalRuleCount), "counting rules")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&totalRuleCount), "counting rules")
|
||||
assert.Equalf(t, totalRuleCount, 1, "reperition rule count mismatch")
|
||||
|
||||
var rule database.RepetitionRule
|
||||
testutils.MustExec(t, db.Preload("Books").First(&rule), "finding b1Record")
|
||||
testutils.MustExec(t, testutils.DB.Preload("Books").First(&rule), "finding b1Record")
|
||||
|
||||
var b1Record database.Book
|
||||
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1Record")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1Record")
|
||||
|
||||
assert.NotEqual(t, rule.UUID, "", "rule UUID mismatch")
|
||||
assert.Equal(t, rule.Title, "Rule 1 - edited", "rule Title mismatch")
|
||||
|
|
@ -416,11 +417,12 @@ func TestUpdateRepetitionRules(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDeleteRepetitionRules(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -439,7 +441,7 @@ func TestDeleteRepetitionRules(t *testing.T) {
|
|||
Books: []database.Book{},
|
||||
NoteCount: 20,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&r1), "preparing r1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing r1")
|
||||
|
||||
r2 := database.RepetitionRule{
|
||||
Title: "Rule 1",
|
||||
|
|
@ -452,7 +454,7 @@ func TestDeleteRepetitionRules(t *testing.T) {
|
|||
Books: []database.Book{},
|
||||
NoteCount: 20,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&r2), "preparing r2")
|
||||
testutils.MustExec(t, testutils.DB.Save(&r2), "preparing r2")
|
||||
|
||||
endpoint := fmt.Sprintf("/repetition_rules/%s", r1.UUID)
|
||||
req := testutils.MakeReq(server, "DELETE", endpoint, "")
|
||||
|
|
@ -462,11 +464,11 @@ func TestDeleteRepetitionRules(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusOK, "")
|
||||
|
||||
var totalRuleCount int
|
||||
testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&totalRuleCount), "counting rules")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&totalRuleCount), "counting rules")
|
||||
assert.Equalf(t, totalRuleCount, 1, "reperition rule count mismatch")
|
||||
|
||||
var r2Count int
|
||||
testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Where("id = ?", r2.ID).Count(&r2Count), "counting r2")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Where("id = ?", r2.ID).Count(&r2Count), "counting r2")
|
||||
assert.Equalf(t, r2Count, 1, "r2 count mismatch")
|
||||
}
|
||||
|
||||
|
|
@ -541,11 +543,12 @@ func TestCreateUpdateRepetitionRules_BadRequest(t *testing.T) {
|
|||
|
||||
for idx, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("test case - create %d", idx), func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -560,13 +563,13 @@ func TestCreateUpdateRepetitionRules_BadRequest(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "")
|
||||
|
||||
var ruleCount int
|
||||
testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
|
||||
assert.Equalf(t, ruleCount, 0, "reperition rule count mismatch")
|
||||
})
|
||||
|
||||
t.Run(fmt.Sprintf("test case %d - update", idx), func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
user := testutils.SetupUserData()
|
||||
|
|
@ -581,15 +584,16 @@ func TestCreateUpdateRepetitionRules_BadRequest(t *testing.T) {
|
|||
Books: []database.Book{},
|
||||
NoteCount: 20,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&r1), "preparing r1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing r1")
|
||||
b1 := database.Book{
|
||||
UserID: user.ID,
|
||||
USN: 11,
|
||||
Label: "js",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing book1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing book1")
|
||||
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -602,7 +606,7 @@ func TestCreateUpdateRepetitionRules_BadRequest(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "")
|
||||
|
||||
var ruleCount int
|
||||
testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
|
||||
assert.Equalf(t, ruleCount, 1, "reperition rule count mismatch")
|
||||
})
|
||||
}
|
||||
|
|
@ -624,11 +628,12 @@ func TestCreateRepetitionRules_BadRequest(t *testing.T) {
|
|||
|
||||
for idx, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -643,7 +648,7 @@ func TestCreateRepetitionRules_BadRequest(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "")
|
||||
|
||||
var ruleCount int
|
||||
testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Count(&ruleCount), "counting rules")
|
||||
assert.Equalf(t, ruleCount, 0, "reperition rule count mismatch")
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ import (
|
|||
"github.com/dnote/dnote/pkg/server/helpers"
|
||||
"github.com/dnote/dnote/pkg/server/log"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stripe/stripe-go"
|
||||
)
|
||||
|
|
@ -63,28 +64,6 @@ func parseAuthHeader(h string) (authHeader, error) {
|
|||
return parsed, nil
|
||||
}
|
||||
|
||||
func legacyAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
db := database.DBConn
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := r.Cookie("api_key")
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid API key", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
apiKey := c.Value
|
||||
var user database.User
|
||||
if db.Where("api_key = ?", apiKey).First(&user).RecordNotFound() {
|
||||
http.Error(w, "Invalid API key", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), helpers.KeyUser, user)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// getSessionKeyFromCookie reads and returns a session key from the cookie sent by the
|
||||
// request. If no session key is found, it returns an empty string
|
||||
func getSessionKeyFromCookie(r *http.Request) (string, error) {
|
||||
|
|
@ -138,8 +117,7 @@ func getCredential(r *http.Request) (string, error) {
|
|||
}
|
||||
|
||||
// AuthWithSession performs user authentication with session
|
||||
func AuthWithSession(r *http.Request, p *AuthMiddlewareParams) (database.User, bool, error) {
|
||||
db := database.DBConn
|
||||
func AuthWithSession(db *gorm.DB, r *http.Request, p *AuthMiddlewareParams) (database.User, bool, error) {
|
||||
var user database.User
|
||||
|
||||
sessionKey, err := getCredential(r)
|
||||
|
|
@ -174,8 +152,7 @@ func AuthWithSession(r *http.Request, p *AuthMiddlewareParams) (database.User, b
|
|||
return user, true, nil
|
||||
}
|
||||
|
||||
func authWithToken(r *http.Request, tokenType string, p *AuthMiddlewareParams) (database.User, database.Token, bool, error) {
|
||||
db := database.DBConn
|
||||
func authWithToken(db *gorm.DB, r *http.Request, tokenType string, p *AuthMiddlewareParams) (database.User, database.Token, bool, error) {
|
||||
var user database.User
|
||||
var token database.Token
|
||||
|
||||
|
|
@ -208,9 +185,9 @@ type AuthMiddlewareParams struct {
|
|||
ProOnly bool
|
||||
}
|
||||
|
||||
func auth(next http.HandlerFunc, p *AuthMiddlewareParams) http.HandlerFunc {
|
||||
func (a *App) auth(next http.HandlerFunc, p *AuthMiddlewareParams) http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok, err := AuthWithSession(r, p)
|
||||
user, ok, err := AuthWithSession(a.DB, r, p)
|
||||
if !ok {
|
||||
respondUnauthorized(w)
|
||||
return
|
||||
|
|
@ -231,9 +208,9 @@ func auth(next http.HandlerFunc, p *AuthMiddlewareParams) http.HandlerFunc {
|
|||
})
|
||||
}
|
||||
|
||||
func tokenAuth(next http.HandlerFunc, tokenType string, p *AuthMiddlewareParams) http.HandlerFunc {
|
||||
func (a *App) tokenAuth(next http.HandlerFunc, tokenType string, p *AuthMiddlewareParams) http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user, token, ok, err := authWithToken(r, tokenType, p)
|
||||
user, token, ok, err := authWithToken(a.DB, r, tokenType, p)
|
||||
if err != nil {
|
||||
// log the error and continue
|
||||
log.ErrorWrap(err, "authenticating with token")
|
||||
|
|
@ -245,7 +222,7 @@ func tokenAuth(next http.HandlerFunc, tokenType string, p *AuthMiddlewareParams)
|
|||
ctx = context.WithValue(ctx, helpers.KeyToken, token)
|
||||
} else {
|
||||
// If token-based auth fails, fall back to session-based auth
|
||||
user, ok, err = AuthWithSession(r, p)
|
||||
user, ok, err = AuthWithSession(a.DB, r, p)
|
||||
if err != nil {
|
||||
HandleError(w, "authenticating with session", err, http.StatusInternalServerError)
|
||||
return
|
||||
|
|
@ -325,6 +302,7 @@ func applyMiddleware(h http.HandlerFunc, rateLimit bool) http.Handler {
|
|||
|
||||
// App is an application configuration
|
||||
type App struct {
|
||||
DB *gorm.DB
|
||||
Clock clock.Clock
|
||||
StripeAPIBackend stripe.Backend
|
||||
WebURL string
|
||||
|
|
@ -334,6 +312,9 @@ func (a *App) validate() error {
|
|||
if a.WebURL == "" {
|
||||
return errors.New("WebURL is empty")
|
||||
}
|
||||
if a.DB == nil {
|
||||
return errors.New("DB is empty")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -364,51 +345,50 @@ func NewRouter(app *App) (*mux.Router, error) {
|
|||
var routes = []Route{
|
||||
// internal
|
||||
{"GET", "/health", app.checkHealth, false},
|
||||
{"GET", "/me", auth(app.getMe, nil), true},
|
||||
{"POST", "/verification-token", auth(app.createVerificationToken, nil), true},
|
||||
{"GET", "/me", app.auth(app.getMe, nil), true},
|
||||
{"POST", "/verification-token", app.auth(app.createVerificationToken, nil), true},
|
||||
{"PATCH", "/verify-email", app.verifyEmail, true},
|
||||
{"POST", "/reset-token", app.createResetToken, true},
|
||||
{"PATCH", "/reset-password", app.resetPassword, true},
|
||||
{"PATCH", "/account/profile", auth(app.updateProfile, nil), true},
|
||||
{"PATCH", "/account/password", auth(app.updatePassword, nil), true},
|
||||
{"GET", "/account/email-preference", tokenAuth(app.getEmailPreference, database.TokenTypeEmailPreference, nil), true},
|
||||
{"PATCH", "/account/email-preference", tokenAuth(app.updateEmailPreference, database.TokenTypeEmailPreference, nil), true},
|
||||
{"POST", "/subscriptions", auth(app.createSub, nil), true},
|
||||
{"PATCH", "/subscriptions", auth(app.updateSub, nil), true},
|
||||
{"PATCH", "/account/profile", app.auth(app.updateProfile, nil), true},
|
||||
{"PATCH", "/account/password", app.auth(app.updatePassword, nil), true},
|
||||
{"GET", "/account/email-preference", app.tokenAuth(app.getEmailPreference, database.TokenTypeEmailPreference, nil), true},
|
||||
{"PATCH", "/account/email-preference", app.tokenAuth(app.updateEmailPreference, database.TokenTypeEmailPreference, nil), true},
|
||||
{"POST", "/subscriptions", app.auth(app.createSub, nil), true},
|
||||
{"PATCH", "/subscriptions", app.auth(app.updateSub, nil), true},
|
||||
{"POST", "/webhooks/stripe", app.stripeWebhook, true},
|
||||
{"GET", "/subscriptions", auth(app.getSub, nil), true},
|
||||
{"GET", "/stripe_source", auth(app.getStripeSource, nil), true},
|
||||
{"PATCH", "/stripe_source", auth(app.updateStripeSource, nil), true},
|
||||
{"GET", "/notes", auth(app.getNotes, &proOnly), false},
|
||||
{"GET", "/subscriptions", app.auth(app.getSub, nil), true},
|
||||
{"GET", "/stripe_source", app.auth(app.getStripeSource, nil), true},
|
||||
{"PATCH", "/stripe_source", app.auth(app.updateStripeSource, nil), true},
|
||||
{"GET", "/notes", app.auth(app.getNotes, &proOnly), false},
|
||||
{"GET", "/notes/{noteUUID}", app.getNote, true},
|
||||
{"GET", "/calendar", auth(app.getCalendar, &proOnly), true},
|
||||
{"GET", "/repetition_rules", auth(app.getRepetitionRules, &proOnly), true},
|
||||
{"GET", "/repetition_rules/{repetitionRuleUUID}", tokenAuth(app.getRepetitionRule, database.TokenTypeRepetition, &proOnly), true},
|
||||
{"POST", "/repetition_rules", auth(app.createRepetitionRule, &proOnly), true},
|
||||
{"PATCH", "/repetition_rules/{repetitionRuleUUID}", tokenAuth(app.updateRepetitionRule, database.TokenTypeRepetition, &proOnly), true},
|
||||
{"DELETE", "/repetition_rules/{repetitionRuleUUID}", auth(app.deleteRepetitionRule, &proOnly), true},
|
||||
{"GET", "/calendar", app.auth(app.getCalendar, &proOnly), true},
|
||||
{"GET", "/repetition_rules", app.auth(app.getRepetitionRules, &proOnly), true},
|
||||
{"GET", "/repetition_rules/{repetitionRuleUUID}", app.tokenAuth(app.getRepetitionRule, database.TokenTypeRepetition, &proOnly), true},
|
||||
{"POST", "/repetition_rules", app.auth(app.createRepetitionRule, &proOnly), true},
|
||||
{"PATCH", "/repetition_rules/{repetitionRuleUUID}", app.tokenAuth(app.updateRepetitionRule, database.TokenTypeRepetition, &proOnly), true},
|
||||
{"DELETE", "/repetition_rules/{repetitionRuleUUID}", app.auth(app.deleteRepetitionRule, &proOnly), true},
|
||||
|
||||
// migration of classic users
|
||||
{"GET", "/classic/presignin", cors(app.classicPresignin), true},
|
||||
{"POST", "/classic/signin", cors(app.classicSignin), true},
|
||||
{"PATCH", "/classic/migrate", auth(app.classicMigrate, &proOnly), true},
|
||||
{"GET", "/classic/notes", auth(app.classicGetNotes, nil), true},
|
||||
{"PATCH", "/classic/set-password", auth(app.classicSetPassword, nil), true},
|
||||
{"PATCH", "/classic/migrate", app.auth(app.classicMigrate, &proOnly), true},
|
||||
{"GET", "/classic/notes", app.auth(app.classicGetNotes, nil), true},
|
||||
{"PATCH", "/classic/set-password", app.auth(app.classicSetPassword, nil), true},
|
||||
|
||||
// v3
|
||||
{"GET", "/v3/sync/fragment", cors(auth(app.GetSyncFragment, &proOnly)), true},
|
||||
{"GET", "/v3/sync/state", cors(auth(app.GetSyncState, &proOnly)), true},
|
||||
{"GET", "/v3/sync/fragment", cors(app.auth(app.GetSyncFragment, &proOnly)), true},
|
||||
{"GET", "/v3/sync/state", cors(app.auth(app.GetSyncState, &proOnly)), true},
|
||||
{"OPTIONS", "/v3/books", cors(app.BooksOptions), true},
|
||||
{"GET", "/v3/books", cors(auth(app.GetBooks, &proOnly)), true},
|
||||
{"GET", "/v3/books/{bookUUID}", cors(auth(app.GetBook, &proOnly)), true},
|
||||
{"POST", "/v3/books", cors(auth(app.CreateBook, &proOnly)), true},
|
||||
{"PATCH", "/v3/books/{bookUUID}", cors(auth(app.UpdateBook, &proOnly)), false},
|
||||
{"DELETE", "/v3/books/{bookUUID}", cors(auth(app.DeleteBook, &proOnly)), false},
|
||||
{"GET", "/v3/demo/books", app.GetDemoBooks, true},
|
||||
{"GET", "/v3/books", cors(app.auth(app.GetBooks, &proOnly)), true},
|
||||
{"GET", "/v3/books/{bookUUID}", cors(app.auth(app.GetBook, &proOnly)), true},
|
||||
{"POST", "/v3/books", cors(app.auth(app.CreateBook, &proOnly)), true},
|
||||
{"PATCH", "/v3/books/{bookUUID}", cors(app.auth(app.UpdateBook, &proOnly)), false},
|
||||
{"DELETE", "/v3/books/{bookUUID}", cors(app.auth(app.DeleteBook, &proOnly)), false},
|
||||
{"OPTIONS", "/v3/notes", cors(app.NotesOptions), true},
|
||||
{"POST", "/v3/notes", cors(auth(app.CreateNote, &proOnly)), true},
|
||||
{"PATCH", "/v3/notes/{noteUUID}", auth(app.UpdateNote, &proOnly), false},
|
||||
{"DELETE", "/v3/notes/{noteUUID}", auth(app.DeleteNote, &proOnly), false},
|
||||
{"POST", "/v3/notes", cors(app.auth(app.CreateNote, &proOnly)), true},
|
||||
{"PATCH", "/v3/notes/{noteUUID}", app.auth(app.UpdateNote, &proOnly), false},
|
||||
{"DELETE", "/v3/notes/{noteUUID}", app.auth(app.DeleteNote, &proOnly), false},
|
||||
{"POST", "/v3/signin", cors(app.signin), true},
|
||||
{"OPTIONS", "/v3/signout", cors(app.signoutOptions), true},
|
||||
{"POST", "/v3/signout", cors(app.signout), true},
|
||||
|
|
|
|||
|
|
@ -29,13 +29,10 @@ import (
|
|||
"github.com/dnote/dnote/pkg/clock"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/testutils"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func init() {
|
||||
testutils.InitTestDB()
|
||||
}
|
||||
|
||||
func TestGetSessionKeyFromCookie(t *testing.T) {
|
||||
testCases := []struct {
|
||||
cookie *http.Cookie
|
||||
|
|
@ -185,10 +182,8 @@ func TestGetCredential(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAuthMiddleware(t *testing.T) {
|
||||
defer testutils.ClearData()
|
||||
|
||||
// set up
|
||||
db := database.DBConn
|
||||
defer testutils.ClearData()
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
session := database.Session{
|
||||
|
|
@ -196,18 +191,19 @@ func TestAuthMiddleware(t *testing.T) {
|
|||
UserID: user.ID,
|
||||
ExpiresAt: time.Now().Add(time.Hour * 24),
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&session), "preparing session")
|
||||
testutils.MustExec(t, testutils.DB.Save(&session), "preparing session")
|
||||
session2 := database.Session{
|
||||
Key: "Vvgm3eBXfXGEFWERI7faiRJ3DAzJw+7DdT9J1LEyNfI=",
|
||||
UserID: user.ID,
|
||||
ExpiresAt: time.Now().Add(-time.Hour * 24),
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&session2), "preparing session")
|
||||
testutils.MustExec(t, testutils.DB.Save(&session2), "preparing session")
|
||||
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
server := httptest.NewServer(auth(handler, nil))
|
||||
app := App{DB: testutils.DB}
|
||||
server := httptest.NewServer(app.auth(handler, nil))
|
||||
defer server.Close()
|
||||
|
||||
t.Run("with header", func(t *testing.T) {
|
||||
|
|
@ -300,24 +296,23 @@ func TestAuthMiddleware(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAuthMiddleware_ProOnly(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
|
||||
// set up
|
||||
db := database.DBConn
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&user).Update("cloud", false), "preparing session")
|
||||
testutils.MustExec(t, testutils.DB.Model(&user).Update("cloud", false), "preparing session")
|
||||
session := database.Session{
|
||||
Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=",
|
||||
UserID: user.ID,
|
||||
ExpiresAt: time.Now().Add(time.Hour * 24),
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&session), "preparing session")
|
||||
testutils.MustExec(t, testutils.DB.Save(&session), "preparing session")
|
||||
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
server := httptest.NewServer(auth(handler, &AuthMiddlewareParams{
|
||||
app := App{DB: testutils.DB}
|
||||
server := httptest.NewServer(app.auth(handler, &AuthMiddlewareParams{
|
||||
ProOnly: true,
|
||||
}))
|
||||
defer server.Close()
|
||||
|
|
@ -390,10 +385,8 @@ func TestAuthMiddleware_ProOnly(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTokenAuthMiddleWare(t *testing.T) {
|
||||
defer testutils.ClearData()
|
||||
|
||||
// set up
|
||||
db := database.DBConn
|
||||
defer testutils.ClearData()
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
tok := database.Token{
|
||||
|
|
@ -401,18 +394,19 @@ func TestTokenAuthMiddleWare(t *testing.T) {
|
|||
Type: database.TokenTypeEmailPreference,
|
||||
Value: "xpwFnc0MdllFUePDq9DLeQ==",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&tok), "preparing token")
|
||||
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
|
||||
session := database.Session{
|
||||
Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=",
|
||||
UserID: user.ID,
|
||||
ExpiresAt: time.Now().Add(time.Hour * 24),
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&session), "preparing session")
|
||||
testutils.MustExec(t, testutils.DB.Save(&session), "preparing session")
|
||||
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
server := httptest.NewServer(tokenAuth(handler, database.TokenTypeEmailPreference, nil))
|
||||
app := App{DB: testutils.DB}
|
||||
server := httptest.NewServer(app.tokenAuth(handler, database.TokenTypeEmailPreference, nil))
|
||||
defer server.Close()
|
||||
|
||||
t.Run("with token", func(t *testing.T) {
|
||||
|
|
@ -521,30 +515,29 @@ func TestTokenAuthMiddleWare(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTokenAuthMiddleWare_ProOnly(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
|
||||
// set up
|
||||
db := database.DBConn
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&user).Update("cloud", false), "preparing session")
|
||||
testutils.MustExec(t, testutils.DB.Model(&user).Update("cloud", false), "preparing session")
|
||||
tok := database.Token{
|
||||
UserID: user.ID,
|
||||
Type: database.TokenTypeEmailPreference,
|
||||
Value: "xpwFnc0MdllFUePDq9DLeQ==",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&tok), "preparing token")
|
||||
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
|
||||
session := database.Session{
|
||||
Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=",
|
||||
UserID: user.ID,
|
||||
ExpiresAt: time.Now().Add(time.Hour * 24),
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&session), "preparing session")
|
||||
testutils.MustExec(t, testutils.DB.Save(&session), "preparing session")
|
||||
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
server := httptest.NewServer(tokenAuth(handler, database.TokenTypeEmailPreference, &AuthMiddlewareParams{
|
||||
app := App{DB: testutils.DB}
|
||||
server := httptest.NewServer(app.tokenAuth(handler, database.TokenTypeEmailPreference, &AuthMiddlewareParams{
|
||||
ProOnly: true,
|
||||
}))
|
||||
defer server.Close()
|
||||
|
|
@ -682,6 +675,7 @@ func TestNotSupportedVersions(t *testing.T) {
|
|||
|
||||
// setup
|
||||
server := MustNewServer(t, &App{
|
||||
DB: &gorm.DB{},
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
|
|||
|
|
@ -26,9 +26,9 @@ import (
|
|||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/helpers"
|
||||
"github.com/dnote/dnote/pkg/server/operations"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stripe/stripe-go"
|
||||
|
|
@ -138,8 +138,7 @@ func (a *App) createSub(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
db := database.DBConn
|
||||
tx := db.Begin()
|
||||
tx := a.DB.Begin()
|
||||
|
||||
if err := tx.Model(&user).
|
||||
Update(map[string]interface{}{
|
||||
|
|
@ -431,8 +430,7 @@ func (a *App) updateStripeSource(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
db := database.DBConn
|
||||
tx := db.Begin()
|
||||
tx := a.DB.Begin()
|
||||
|
||||
if err := tx.Model(&user).
|
||||
Update(map[string]interface{}{
|
||||
|
|
@ -532,7 +530,7 @@ func (a *App) stripeWebhook(w http.ResponseWriter, req *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
operations.MarkUnsubscribed(subscription.Customer.ID)
|
||||
operations.MarkUnsubscribed(a.DB, subscription.Customer.ID)
|
||||
}
|
||||
default:
|
||||
{
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import (
|
|||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/dnote/dnote/pkg/server/testutils"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
|
@ -30,6 +31,7 @@ import (
|
|||
// with the given app paratmers
|
||||
func MustNewServer(t *testing.T, app *App) *httptest.Server {
|
||||
app.WebURL = os.Getenv("WebURL")
|
||||
app.DB = testutils.DB
|
||||
|
||||
r, err := NewRouter(app)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -23,11 +23,12 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/dnote/dnote/pkg/server/helpers"
|
||||
"github.com/dnote/dnote/pkg/server/presenters"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/helpers"
|
||||
"github.com/dnote/dnote/pkg/server/log"
|
||||
"github.com/dnote/dnote/pkg/server/mailer"
|
||||
"github.com/dnote/dnote/pkg/server/presenters"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
|
@ -38,8 +39,6 @@ type updateProfilePayload struct {
|
|||
|
||||
// updateProfile updates user
|
||||
func (a *App) updateProfile(w http.ResponseWriter, r *http.Request) {
|
||||
db := database.DBConn
|
||||
|
||||
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
|
||||
if !ok {
|
||||
HandleError(w, "No authenticated user found", nil, http.StatusInternalServerError)
|
||||
|
|
@ -60,13 +59,13 @@ func (a *App) updateProfile(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var account database.Account
|
||||
err = db.Where("user_id = ?", user.ID).First(&account).Error
|
||||
err = a.DB.Where("user_id = ?", user.ID).First(&account).Error
|
||||
if err != nil {
|
||||
HandleError(w, "finding account", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
tx := db.Begin()
|
||||
tx := a.DB.Begin()
|
||||
if err := tx.Save(&user).Error; err != nil {
|
||||
tx.Rollback()
|
||||
HandleError(w, "saving user", err, http.StatusInternalServerError)
|
||||
|
|
@ -87,7 +86,7 @@ func (a *App) updateProfile(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
tx.Commit()
|
||||
|
||||
respondWithSession(w, user.ID, http.StatusOK)
|
||||
respondWithSession(a.DB, w, user.ID, http.StatusOK)
|
||||
}
|
||||
|
||||
type updateEmailPayload struct {
|
||||
|
|
@ -97,9 +96,7 @@ type updateEmailPayload struct {
|
|||
NewAuthKey string `json:"new_auth_key"`
|
||||
}
|
||||
|
||||
func respondWithCalendar(w http.ResponseWriter, userID int) {
|
||||
db := database.DBConn
|
||||
|
||||
func respondWithCalendar(db *gorm.DB, w http.ResponseWriter, userID int) {
|
||||
rows, err := db.Table("notes").Select("COUNT(id), date(to_timestamp(added_on/1000000000)) AS added_date").
|
||||
Where("user_id = ?", userID).
|
||||
Group("added_date").
|
||||
|
|
@ -132,22 +129,10 @@ func (a *App) getCalendar(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
respondWithCalendar(w, user.ID)
|
||||
}
|
||||
|
||||
func (a *App) getDemoCalendar(w http.ResponseWriter, r *http.Request) {
|
||||
userID, err := helpers.GetDemoUserID()
|
||||
if err != nil {
|
||||
HandleError(w, "finding demo user", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
respondWithCalendar(w, userID)
|
||||
respondWithCalendar(a.DB, w, user.ID)
|
||||
}
|
||||
|
||||
func (a *App) createVerificationToken(w http.ResponseWriter, r *http.Request) {
|
||||
db := database.DBConn
|
||||
|
||||
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
|
||||
if !ok {
|
||||
HandleError(w, "No authenticated user found", nil, http.StatusInternalServerError)
|
||||
|
|
@ -155,7 +140,7 @@ func (a *App) createVerificationToken(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var account database.Account
|
||||
err := db.Where("user_id = ?", user.ID).First(&account).Error
|
||||
err := a.DB.Where("user_id = ?", user.ID).First(&account).Error
|
||||
if err != nil {
|
||||
HandleError(w, "finding account", err, http.StatusInternalServerError)
|
||||
return
|
||||
|
|
@ -182,7 +167,7 @@ func (a *App) createVerificationToken(w http.ResponseWriter, r *http.Request) {
|
|||
Type: database.TokenTypeEmailVerification,
|
||||
}
|
||||
|
||||
if err := db.Save(&token).Error; err != nil {
|
||||
if err := a.DB.Save(&token).Error; err != nil {
|
||||
HandleError(w, "saving token", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -212,8 +197,6 @@ type verifyEmailPayload struct {
|
|||
}
|
||||
|
||||
func (a *App) verifyEmail(w http.ResponseWriter, r *http.Request) {
|
||||
db := database.DBConn
|
||||
|
||||
var params verifyEmailPayload
|
||||
if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil {
|
||||
HandleError(w, "decoding payload", err, http.StatusInternalServerError)
|
||||
|
|
@ -221,7 +204,7 @@ func (a *App) verifyEmail(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var token database.Token
|
||||
if err := db.
|
||||
if err := a.DB.
|
||||
Where("value = ? AND type = ?", params.Token, database.TokenTypeEmailVerification).
|
||||
First(&token).Error; err != nil {
|
||||
http.Error(w, "invalid token", http.StatusBadRequest)
|
||||
|
|
@ -240,7 +223,7 @@ func (a *App) verifyEmail(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var account database.Account
|
||||
if err := db.Where("user_id = ?", token.UserID).First(&account).Error; err != nil {
|
||||
if err := a.DB.Where("user_id = ?", token.UserID).First(&account).Error; err != nil {
|
||||
HandleError(w, "finding account", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -249,7 +232,7 @@ func (a *App) verifyEmail(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
tx := db.Begin()
|
||||
tx := a.DB.Begin()
|
||||
account.EmailVerified = true
|
||||
if err := tx.Save(&account).Error; err != nil {
|
||||
tx.Rollback()
|
||||
|
|
@ -264,7 +247,7 @@ func (a *App) verifyEmail(w http.ResponseWriter, r *http.Request) {
|
|||
tx.Commit()
|
||||
|
||||
var user database.User
|
||||
if err := db.Where("id = ?", token.UserID).First(&user).Error; err != nil {
|
||||
if err := a.DB.Where("id = ?", token.UserID).First(&user).Error; err != nil {
|
||||
HandleError(w, "finding user", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -278,8 +261,6 @@ type updateEmailPreferencePayload struct {
|
|||
}
|
||||
|
||||
func (a *App) updateEmailPreference(w http.ResponseWriter, r *http.Request) {
|
||||
db := database.DBConn
|
||||
|
||||
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
|
||||
if !ok {
|
||||
HandleError(w, "No authenticated user found", nil, http.StatusInternalServerError)
|
||||
|
|
@ -293,12 +274,12 @@ func (a *App) updateEmailPreference(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var frequency database.EmailPreference
|
||||
if err := db.Where(database.EmailPreference{UserID: user.ID}).FirstOrCreate(&frequency).Error; err != nil {
|
||||
if err := a.DB.Where(database.EmailPreference{UserID: user.ID}).FirstOrCreate(&frequency).Error; err != nil {
|
||||
HandleError(w, "finding frequency", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
tx := db.Begin()
|
||||
tx := a.DB.Begin()
|
||||
|
||||
frequency.DigestWeekly = params.DigestWeekly
|
||||
if err := tx.Save(&frequency).Error; err != nil {
|
||||
|
|
@ -323,8 +304,6 @@ func (a *App) updateEmailPreference(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func (a *App) getEmailPreference(w http.ResponseWriter, r *http.Request) {
|
||||
db := database.DBConn
|
||||
|
||||
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
|
||||
if !ok {
|
||||
HandleError(w, "No authenticated user found", nil, http.StatusInternalServerError)
|
||||
|
|
@ -332,7 +311,7 @@ func (a *App) getEmailPreference(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var pref database.EmailPreference
|
||||
if err := db.Where(database.EmailPreference{UserID: user.ID}).First(&pref).Error; err != nil {
|
||||
if err := a.DB.Where(database.EmailPreference{UserID: user.ID}).First(&pref).Error; err != nil {
|
||||
HandleError(w, "finding pref", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -347,8 +326,6 @@ type updatePasswordPayload struct {
|
|||
}
|
||||
|
||||
func (a *App) updatePassword(w http.ResponseWriter, r *http.Request) {
|
||||
db := database.DBConn
|
||||
|
||||
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
|
||||
if !ok {
|
||||
HandleError(w, "No authenticated user found", nil, http.StatusInternalServerError)
|
||||
|
|
@ -366,7 +343,7 @@ func (a *App) updatePassword(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var account database.Account
|
||||
if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
|
||||
if err := a.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
|
||||
HandleError(w, "getting user", nil, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -391,7 +368,7 @@ func (a *App) updatePassword(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if err := db.Model(&account).Update("password", string(hashedNewPassword)).Error; err != nil {
|
||||
if err := a.DB.Model(&account).Update("password", string(hashedNewPassword)).Error; err != nil {
|
||||
http.Error(w, errors.Wrap(err, "updating password").Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,18 +36,14 @@ import (
|
|||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func init() {
|
||||
testutils.InitTestDB()
|
||||
|
||||
}
|
||||
|
||||
func TestUpdatePassword(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -64,18 +60,19 @@ func TestUpdatePassword(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusOK, "Status code mismsatch")
|
||||
|
||||
var account database.Account
|
||||
testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
|
||||
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte("newpassword"))
|
||||
assert.Equal(t, passwordErr, nil, "Password mismatch")
|
||||
})
|
||||
|
||||
t.Run("old password mismatch", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -92,16 +89,17 @@ func TestUpdatePassword(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "Status code mismsatch")
|
||||
|
||||
var account database.Account
|
||||
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account")
|
||||
assert.Equal(t, a.Password.String, account.Password.String, "password should not have been updated")
|
||||
})
|
||||
|
||||
t.Run("password too short", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -118,15 +116,15 @@ func TestUpdatePassword(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismsatch")
|
||||
|
||||
var account database.Account
|
||||
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account")
|
||||
assert.Equal(t, a.Password.String, account.Password.String, "password should not have been updated")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateVerificationToken(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
|
||||
|
|
@ -137,6 +135,7 @@ func TestCreateVerificationToken(t *testing.T) {
|
|||
mailer.InitTemplates(&templatePath)
|
||||
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -154,9 +153,9 @@ func TestCreateVerificationToken(t *testing.T) {
|
|||
var account database.Account
|
||||
var token database.Token
|
||||
var tokenCount int
|
||||
testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
|
||||
testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
|
||||
|
||||
assert.Equal(t, account.EmailVerified, false, "email_verified should not have been updated")
|
||||
assert.NotEqual(t, token.Value, "", "token Value mismatch")
|
||||
|
|
@ -165,11 +164,12 @@ func TestCreateVerificationToken(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("already verified", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -177,7 +177,7 @@ func TestCreateVerificationToken(t *testing.T) {
|
|||
user := testutils.SetupUserData()
|
||||
a := testutils.SetupAccountData(user, "alice@example.com", "pass1234")
|
||||
a.EmailVerified = true
|
||||
testutils.MustExec(t, db.Save(&a), "preparing account")
|
||||
testutils.MustExec(t, testutils.DB.Save(&a), "preparing account")
|
||||
|
||||
// Execute
|
||||
req := testutils.MakeReq(server, "POST", "/verification-token", "")
|
||||
|
|
@ -188,8 +188,8 @@ func TestCreateVerificationToken(t *testing.T) {
|
|||
|
||||
var account database.Account
|
||||
var tokenCount int
|
||||
testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
|
||||
|
||||
assert.Equal(t, account.EmailVerified, true, "email_verified should not have been updated")
|
||||
assert.Equal(t, tokenCount, 0, "token count mismatch")
|
||||
|
|
@ -198,11 +198,12 @@ func TestCreateVerificationToken(t *testing.T) {
|
|||
|
||||
func TestVerifyEmail(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -214,7 +215,7 @@ func TestVerifyEmail(t *testing.T) {
|
|||
Type: database.TokenTypeEmailVerification,
|
||||
Value: "someTokenValue",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&tok), "preparing token")
|
||||
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
|
||||
|
||||
dat := `{"token": "someTokenValue"}`
|
||||
req := testutils.MakeReq(server, "PATCH", "/verify-email", dat)
|
||||
|
|
@ -228,9 +229,9 @@ func TestVerifyEmail(t *testing.T) {
|
|||
var account database.Account
|
||||
var token database.Token
|
||||
var tokenCount int
|
||||
testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
|
||||
testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
|
||||
|
||||
assert.Equal(t, account.EmailVerified, true, "email_verified mismatch")
|
||||
assert.NotEqual(t, token.Value, "", "token value should not have been updated")
|
||||
|
|
@ -239,11 +240,12 @@ func TestVerifyEmail(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("used token", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -258,7 +260,7 @@ func TestVerifyEmail(t *testing.T) {
|
|||
Value: "someTokenValue",
|
||||
UsedAt: &usedAt,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&tok), "preparing token")
|
||||
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
|
||||
|
||||
dat := `{"token": "someTokenValue"}`
|
||||
req := testutils.MakeReq(server, "PATCH", "/verify-email", dat)
|
||||
|
|
@ -272,9 +274,9 @@ func TestVerifyEmail(t *testing.T) {
|
|||
var account database.Account
|
||||
var token database.Token
|
||||
var tokenCount int
|
||||
testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
|
||||
testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
|
||||
|
||||
assert.Equal(t, account.EmailVerified, false, "email_verified mismatch")
|
||||
assert.NotEqual(t, token.UsedAt, nil, "token used_at mismatch")
|
||||
|
|
@ -283,11 +285,12 @@ func TestVerifyEmail(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("expired token", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -300,8 +303,8 @@ func TestVerifyEmail(t *testing.T) {
|
|||
Type: database.TokenTypeEmailVerification,
|
||||
Value: "someTokenValue",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&tok), "preparing token")
|
||||
testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-31)), "Failed to prepare token created_at")
|
||||
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
|
||||
testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-31)), "Failed to prepare token created_at")
|
||||
|
||||
dat := `{"token": "someTokenValue"}`
|
||||
req := testutils.MakeReq(server, "PATCH", "/verify-email", dat)
|
||||
|
|
@ -315,9 +318,9 @@ func TestVerifyEmail(t *testing.T) {
|
|||
var account database.Account
|
||||
var token database.Token
|
||||
var tokenCount int
|
||||
testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
|
||||
testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
|
||||
|
||||
assert.Equal(t, account.EmailVerified, false, "email_verified mismatch")
|
||||
assert.Equal(t, tokenCount, 1, "token count mismatch")
|
||||
|
|
@ -325,11 +328,12 @@ func TestVerifyEmail(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("already verified", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -337,14 +341,14 @@ func TestVerifyEmail(t *testing.T) {
|
|||
user := testutils.SetupUserData()
|
||||
a := testutils.SetupAccountData(user, "alice@example.com", "oldpass1234")
|
||||
a.EmailVerified = true
|
||||
testutils.MustExec(t, db.Save(&a), "preparing account")
|
||||
testutils.MustExec(t, testutils.DB.Save(&a), "preparing account")
|
||||
|
||||
tok := database.Token{
|
||||
UserID: user.ID,
|
||||
Type: database.TokenTypeEmailVerification,
|
||||
Value: "someTokenValue",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&tok), "preparing token")
|
||||
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
|
||||
|
||||
dat := `{"token": "someTokenValue"}`
|
||||
req := testutils.MakeReq(server, "PATCH", "/verify-email", dat)
|
||||
|
|
@ -358,9 +362,9 @@ func TestVerifyEmail(t *testing.T) {
|
|||
var account database.Account
|
||||
var token database.Token
|
||||
var tokenCount int
|
||||
testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
|
||||
testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
|
||||
|
||||
assert.Equal(t, account.EmailVerified, true, "email_verified mismatch")
|
||||
assert.Equal(t, tokenCount, 1, "token count mismatch")
|
||||
|
|
@ -370,11 +374,12 @@ func TestVerifyEmail(t *testing.T) {
|
|||
|
||||
func TestUpdateEmail(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -382,7 +387,7 @@ func TestUpdateEmail(t *testing.T) {
|
|||
u := testutils.SetupUserData()
|
||||
a := testutils.SetupAccountData(u, "alice@example.com", "pass1234")
|
||||
a.EmailVerified = true
|
||||
testutils.MustExec(t, db.Save(&a), "updating email_verified")
|
||||
testutils.MustExec(t, testutils.DB.Save(&a), "updating email_verified")
|
||||
|
||||
// Execute
|
||||
dat := `{"email": "alice-new@example.com"}`
|
||||
|
|
@ -394,8 +399,8 @@ func TestUpdateEmail(t *testing.T) {
|
|||
|
||||
var user database.User
|
||||
var account database.Account
|
||||
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user")
|
||||
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", u.ID).First(&user), "finding user")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account")
|
||||
|
||||
assert.Equal(t, account.Email.String, "alice-new@example.com", "email mismatch")
|
||||
assert.Equal(t, account.EmailVerified, false, "EmailVerified mismatch")
|
||||
|
|
@ -404,11 +409,12 @@ func TestUpdateEmail(t *testing.T) {
|
|||
|
||||
func TestUpdateEmailPreference(t *testing.T) {
|
||||
t.Run("with login", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -425,16 +431,17 @@ func TestUpdateEmailPreference(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusOK, "")
|
||||
|
||||
var preference database.EmailPreference
|
||||
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding account")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding account")
|
||||
assert.Equal(t, preference.DigestWeekly, true, "preference mismatch")
|
||||
})
|
||||
|
||||
t.Run("with an unused token", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -446,7 +453,7 @@ func TestUpdateEmailPreference(t *testing.T) {
|
|||
Type: database.TokenTypeEmailPreference,
|
||||
Value: "someTokenValue",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&tok), "preparing token")
|
||||
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
|
||||
|
||||
// Execute
|
||||
dat := `{"digest_weekly": true}`
|
||||
|
|
@ -460,9 +467,9 @@ func TestUpdateEmailPreference(t *testing.T) {
|
|||
var preference database.EmailPreference
|
||||
var preferenceCount int
|
||||
var token database.Token
|
||||
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding preference")
|
||||
testutils.MustExec(t, db.Model(database.EmailPreference{}).Count(&preferenceCount), "counting preference")
|
||||
testutils.MustExec(t, db.Where("id = ?", tok.ID).First(&token), "failed to find token")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding preference")
|
||||
testutils.MustExec(t, testutils.DB.Model(database.EmailPreference{}).Count(&preferenceCount), "counting preference")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", tok.ID).First(&token), "failed to find token")
|
||||
|
||||
assert.Equal(t, preferenceCount, 1, "preference count mismatch")
|
||||
assert.Equal(t, preference.DigestWeekly, true, "email mismatch")
|
||||
|
|
@ -470,11 +477,12 @@ func TestUpdateEmailPreference(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("with nonexistent token", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -486,7 +494,7 @@ func TestUpdateEmailPreference(t *testing.T) {
|
|||
Type: database.TokenTypeEmailPreference,
|
||||
Value: "someTokenValue",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&tok), "preparing token")
|
||||
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
|
||||
|
||||
dat := `{"digest_weekly": false}`
|
||||
url := fmt.Sprintf("/account/email-preference?token=%s", "someNonexistentToken")
|
||||
|
|
@ -499,16 +507,17 @@ func TestUpdateEmailPreference(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
|
||||
|
||||
var preference database.EmailPreference
|
||||
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding preference")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding preference")
|
||||
assert.Equal(t, preference.DigestWeekly, true, "email mismatch")
|
||||
})
|
||||
|
||||
t.Run("with expired token", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -523,7 +532,7 @@ func TestUpdateEmailPreference(t *testing.T) {
|
|||
Value: "someTokenValue",
|
||||
UsedAt: &usedAt,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&tok), "preparing token")
|
||||
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
|
||||
|
||||
// Execute
|
||||
dat := `{"digest_weekly": false}`
|
||||
|
|
@ -535,16 +544,17 @@ func TestUpdateEmailPreference(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
|
||||
|
||||
var preference database.EmailPreference
|
||||
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding preference")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding preference")
|
||||
assert.Equal(t, preference.DigestWeekly, true, "email mismatch")
|
||||
})
|
||||
|
||||
t.Run("with a used but unexpired token", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -558,7 +568,7 @@ func TestUpdateEmailPreference(t *testing.T) {
|
|||
Value: "someTokenValue",
|
||||
UsedAt: &usedAt,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&tok), "preparing token")
|
||||
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
|
||||
|
||||
dat := `{"digest_weekly": false}`
|
||||
url := fmt.Sprintf("/account/email-preference?token=%s", "someTokenValue")
|
||||
|
|
@ -571,16 +581,17 @@ func TestUpdateEmailPreference(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusOK, "")
|
||||
|
||||
var preference database.EmailPreference
|
||||
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding preference")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding preference")
|
||||
assert.Equal(t, preference.DigestWeekly, false, "DigestWeekly mismatch")
|
||||
})
|
||||
|
||||
t.Run("no user and no token", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -597,16 +608,17 @@ func TestUpdateEmailPreference(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
|
||||
|
||||
var preference database.EmailPreference
|
||||
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding preference")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding preference")
|
||||
assert.Equal(t, preference.DigestWeekly, true, "email mismatch")
|
||||
})
|
||||
|
||||
t.Run("create a record if not exists", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -617,7 +629,7 @@ func TestUpdateEmailPreference(t *testing.T) {
|
|||
Type: database.TokenTypeEmailPreference,
|
||||
Value: "someTokenValue",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&tok), "preparing token")
|
||||
testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
|
||||
|
||||
// Execute
|
||||
dat := `{"digest_weekly": false}`
|
||||
|
|
@ -629,20 +641,21 @@ func TestUpdateEmailPreference(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusOK, "")
|
||||
|
||||
var preferenceCount int
|
||||
testutils.MustExec(t, db.Model(database.EmailPreference{}).Count(&preferenceCount), "counting preference")
|
||||
testutils.MustExec(t, testutils.DB.Model(database.EmailPreference{}).Count(&preferenceCount), "counting preference")
|
||||
assert.Equal(t, preferenceCount, 1, "preference count mismatch")
|
||||
|
||||
var preference database.EmailPreference
|
||||
testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&preference), "finding preference")
|
||||
testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&preference), "finding preference")
|
||||
assert.Equal(t, preference.DigestWeekly, false, "email mismatch")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetEmailPreference(t *testing.T) {
|
||||
defer testutils.ClearData()
|
||||
|
||||
defer testutils.ClearData()
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
|
|||
|
|
@ -23,8 +23,9 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/dnote/dnote/pkg/server/operations"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/operations"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
|
@ -63,9 +64,7 @@ func unsetSessionCookie(w http.ResponseWriter) {
|
|||
http.SetCookie(w, &cookie)
|
||||
}
|
||||
|
||||
func touchLastLoginAt(user database.User) error {
|
||||
db := database.DBConn
|
||||
|
||||
func touchLastLoginAt(db *gorm.DB, user database.User) error {
|
||||
t := time.Now()
|
||||
if err := db.Model(&user).Update(database.User{LastLoginAt: &t}).Error; err != nil {
|
||||
return errors.Wrap(err, "updating last_login_at")
|
||||
|
|
@ -80,8 +79,6 @@ type signinPayload struct {
|
|||
}
|
||||
|
||||
func (a *App) signin(w http.ResponseWriter, r *http.Request) {
|
||||
db := database.DBConn
|
||||
|
||||
var params signinPayload
|
||||
err := json.NewDecoder(r.Body).Decode(¶ms)
|
||||
if err != nil {
|
||||
|
|
@ -94,7 +91,7 @@ func (a *App) signin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var account database.Account
|
||||
conn := db.Where("email = ?", params.Email).First(&account)
|
||||
conn := a.DB.Where("email = ?", params.Email).First(&account)
|
||||
if conn.RecordNotFound() {
|
||||
http.Error(w, ErrLoginFailure.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
|
|
@ -111,19 +108,19 @@ func (a *App) signin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var user database.User
|
||||
err = db.Where("id = ?", account.UserID).First(&user).Error
|
||||
err = a.DB.Where("id = ?", account.UserID).First(&user).Error
|
||||
if err != nil {
|
||||
HandleError(w, "finding user", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
err = operations.TouchLastLoginAt(user, db)
|
||||
err = operations.TouchLastLoginAt(user, a.DB)
|
||||
if err != nil {
|
||||
http.Error(w, errors.Wrap(err, "touching login timestamp").Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
respondWithSession(w, account.UserID, http.StatusOK)
|
||||
respondWithSession(a.DB, w, account.UserID, http.StatusOK)
|
||||
}
|
||||
|
||||
func (a *App) signoutOptions(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
@ -143,7 +140,7 @@ func (a *App) signout(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
err = operations.DeleteSession(database.DBConn, key)
|
||||
err = operations.DeleteSession(a.DB, key)
|
||||
if err != nil {
|
||||
HandleError(w, "deleting session", nil, http.StatusInternalServerError)
|
||||
return
|
||||
|
|
@ -182,8 +179,6 @@ func parseRegisterPaylaod(r *http.Request) (registerPayload, bool) {
|
|||
}
|
||||
|
||||
func (a *App) register(w http.ResponseWriter, r *http.Request) {
|
||||
db := database.DBConn
|
||||
|
||||
params, ok := parseRegisterPaylaod(r)
|
||||
if !ok {
|
||||
http.Error(w, "invalid payload", http.StatusBadRequest)
|
||||
|
|
@ -191,7 +186,7 @@ func (a *App) register(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var count int
|
||||
if err := db.Model(database.Account{}).Where("email = ?", params.Email).Count(&count).Error; err != nil {
|
||||
if err := a.DB.Model(database.Account{}).Where("email = ?", params.Email).Count(&count).Error; err != nil {
|
||||
HandleError(w, "checking duplicate user", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
|
@ -200,20 +195,18 @@ func (a *App) register(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
user, err := operations.CreateUser(params.Email, params.Password)
|
||||
user, err := operations.CreateUser(a.DB, params.Email, params.Password)
|
||||
if err != nil {
|
||||
HandleError(w, "creating user", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
respondWithSession(w, user.ID, http.StatusCreated)
|
||||
respondWithSession(a.DB, w, user.ID, http.StatusCreated)
|
||||
}
|
||||
|
||||
// respondWithSession makes a HTTP response with the session from the user with the given userID.
|
||||
// It sets the HTTP-Only cookie for browser clients and also sends a JSON response for non-browser clients.
|
||||
func respondWithSession(w http.ResponseWriter, userID int, statusCode int) {
|
||||
db := database.DBConn
|
||||
|
||||
func respondWithSession(db *gorm.DB, w http.ResponseWriter, userID int, statusCode int) {
|
||||
session, err := operations.CreateSession(db, userID)
|
||||
if err != nil {
|
||||
HandleError(w, "creating session", nil, http.StatusBadRequest)
|
||||
|
|
|
|||
|
|
@ -22,26 +22,17 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dnote/dnote/pkg/assert"
|
||||
"github.com/dnote/dnote/pkg/clock"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/mailer"
|
||||
"github.com/dnote/dnote/pkg/server/testutils"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func init() {
|
||||
testutils.InitTestDB()
|
||||
|
||||
templatePath := os.Getenv("DNOTE_TEST_EMAIL_TEMPLATE_DIR")
|
||||
mailer.InitTemplates(&templatePath)
|
||||
}
|
||||
|
||||
func assertSessionResp(t *testing.T, res *http.Response) {
|
||||
// after register, should sign in user
|
||||
var got SessionResponse
|
||||
|
|
@ -51,9 +42,8 @@ func assertSessionResp(t *testing.T, res *http.Response) {
|
|||
|
||||
var sessionCount int
|
||||
var session database.Session
|
||||
db := database.DBConn
|
||||
testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
|
||||
testutils.MustExec(t, db.First(&session), "getting session")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
|
||||
testutils.MustExec(t, testutils.DB.First(&session), "getting session")
|
||||
|
||||
assert.Equal(t, sessionCount, 1, "sessionCount mismatch")
|
||||
assert.Equal(t, got.Key, session.Key, "session Key mismatch")
|
||||
|
|
@ -87,11 +77,12 @@ func TestRegister(t *testing.T) {
|
|||
|
||||
for _, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("register %s %s", tc.email, tc.password), func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -106,20 +97,20 @@ func TestRegister(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusCreated, "")
|
||||
|
||||
var account database.Account
|
||||
testutils.MustExec(t, db.Where("email = ?", tc.email).First(&account), "finding account")
|
||||
testutils.MustExec(t, testutils.DB.Where("email = ?", tc.email).First(&account), "finding account")
|
||||
assert.Equal(t, account.Email.String, tc.email, "Email mismatch")
|
||||
assert.NotEqual(t, account.UserID, 0, "UserID mismatch")
|
||||
passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte(tc.password))
|
||||
assert.Equal(t, passwordErr, nil, "Password mismatch")
|
||||
|
||||
var user database.User
|
||||
testutils.MustExec(t, db.Where("id = ?", account.UserID).First(&user), "finding user")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", account.UserID).First(&user), "finding user")
|
||||
assert.Equal(t, user.Cloud, false, "Cloud mismatch")
|
||||
assert.Equal(t, user.StripeCustomerID, "", "StripeCustomerID mismatch")
|
||||
assert.Equal(t, user.MaxUSN, 0, "MaxUSN mismatch")
|
||||
|
||||
var repetitionRuleCount int
|
||||
testutils.MustExec(t, db.Model(&database.RepetitionRule{}).Where("user_id = ?", account.UserID).Count(&repetitionRuleCount), "counting repetition rules")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.RepetitionRule{}).Where("user_id = ?", account.UserID).Count(&repetitionRuleCount), "counting repetition rules")
|
||||
assert.Equal(t, repetitionRuleCount, 1, "repetitionRuleCount mismatch")
|
||||
|
||||
// after register, should sign in user
|
||||
|
|
@ -130,11 +121,12 @@ func TestRegister(t *testing.T) {
|
|||
|
||||
func TestRegisterMissingParams(t *testing.T) {
|
||||
t.Run("missing email", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -149,19 +141,20 @@ func TestRegisterMissingParams(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch")
|
||||
|
||||
var accountCount, userCount int
|
||||
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
|
||||
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
|
||||
|
||||
assert.Equal(t, accountCount, 0, "accountCount mismatch")
|
||||
assert.Equal(t, userCount, 0, "userCount mismatch")
|
||||
})
|
||||
|
||||
t.Run("missing password", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -176,8 +169,8 @@ func TestRegisterMissingParams(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch")
|
||||
|
||||
var accountCount, userCount int
|
||||
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
|
||||
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
|
||||
|
||||
assert.Equal(t, accountCount, 0, "accountCount mismatch")
|
||||
assert.Equal(t, userCount, 0, "userCount mismatch")
|
||||
|
|
@ -185,11 +178,12 @@ func TestRegisterMissingParams(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRegisterDuplicateEmail(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -207,12 +201,12 @@ func TestRegisterDuplicateEmail(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "status code mismatch")
|
||||
|
||||
var accountCount, userCount, verificationTokenCount int
|
||||
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
|
||||
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
|
||||
testutils.MustExec(t, db.Model(&database.Token{}).Count(&verificationTokenCount), "counting verification token")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&verificationTokenCount), "counting verification token")
|
||||
|
||||
var user database.User
|
||||
testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", u.ID).First(&user), "finding user")
|
||||
|
||||
assert.Equal(t, accountCount, 1, "account count mismatch")
|
||||
assert.Equal(t, userCount, 1, "user count mismatch")
|
||||
|
|
@ -222,11 +216,12 @@ func TestRegisterDuplicateEmail(t *testing.T) {
|
|||
|
||||
func TestSignIn(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -244,7 +239,7 @@ func TestSignIn(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusOK, "")
|
||||
|
||||
var user database.User
|
||||
testutils.MustExec(t, db.Model(&database.User{}).First(&user), "finding user")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.User{}).First(&user), "finding user")
|
||||
assert.NotEqual(t, user.LastLoginAt, nil, "LastLoginAt mismatch")
|
||||
|
||||
// after register, should sign in user
|
||||
|
|
@ -252,11 +247,12 @@ func TestSignIn(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("wrong password", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -274,20 +270,21 @@ func TestSignIn(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
|
||||
|
||||
var user database.User
|
||||
testutils.MustExec(t, db.Model(&database.User{}).First(&user), "finding user")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.User{}).First(&user), "finding user")
|
||||
assert.Equal(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch")
|
||||
|
||||
var sessionCount int
|
||||
testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
|
||||
assert.Equal(t, sessionCount, 0, "sessionCount mismatch")
|
||||
})
|
||||
|
||||
t.Run("wrong email", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -305,20 +302,21 @@ func TestSignIn(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
|
||||
|
||||
var user database.User
|
||||
testutils.MustExec(t, db.Model(&database.User{}).First(&user), "finding user")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.User{}).First(&user), "finding user")
|
||||
assert.DeepEqual(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch")
|
||||
|
||||
var sessionCount int
|
||||
testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
|
||||
assert.Equal(t, sessionCount, 0, "sessionCount mismatch")
|
||||
})
|
||||
|
||||
t.Run("nonexistent email", func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -333,14 +331,14 @@ func TestSignIn(t *testing.T) {
|
|||
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
|
||||
|
||||
var sessionCount int
|
||||
testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
|
||||
assert.Equal(t, sessionCount, 0, "sessionCount mismatch")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSignout(t *testing.T) {
|
||||
t.Run("authenticated", func(t *testing.T) {
|
||||
db := database.DBConn
|
||||
|
||||
defer testutils.ClearData()
|
||||
|
||||
aliceUser := testutils.SetupUserData()
|
||||
|
|
@ -352,16 +350,17 @@ func TestSignout(t *testing.T) {
|
|||
UserID: aliceUser.ID,
|
||||
ExpiresAt: time.Now().Add(time.Hour * 24),
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&session1), "preparing session1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&session1), "preparing session1")
|
||||
session2 := database.Session{
|
||||
Key: "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=",
|
||||
UserID: anotherUser.ID,
|
||||
ExpiresAt: time.Now().Add(time.Hour * 24),
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&session2), "preparing session2")
|
||||
testutils.MustExec(t, testutils.DB.Save(&session2), "preparing session2")
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -376,8 +375,8 @@ func TestSignout(t *testing.T) {
|
|||
|
||||
var sessionCount int
|
||||
var s2 database.Session
|
||||
testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
|
||||
testutils.MustExec(t, db.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&s2), "getting s2")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
|
||||
testutils.MustExec(t, testutils.DB.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&s2), "getting s2")
|
||||
|
||||
assert.Equal(t, sessionCount, 1, "sessionCount mismatch")
|
||||
|
||||
|
|
@ -391,7 +390,7 @@ func TestSignout(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("unauthenticated", func(t *testing.T) {
|
||||
db := database.DBConn
|
||||
|
||||
defer testutils.ClearData()
|
||||
|
||||
aliceUser := testutils.SetupUserData()
|
||||
|
|
@ -403,16 +402,17 @@ func TestSignout(t *testing.T) {
|
|||
UserID: aliceUser.ID,
|
||||
ExpiresAt: time.Now().Add(time.Hour * 24),
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&session1), "preparing session1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&session1), "preparing session1")
|
||||
session2 := database.Session{
|
||||
Key: "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=",
|
||||
UserID: anotherUser.ID,
|
||||
ExpiresAt: time.Now().Add(time.Hour * 24),
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&session2), "preparing session2")
|
||||
testutils.MustExec(t, testutils.DB.Save(&session2), "preparing session2")
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
|
@ -426,9 +426,9 @@ func TestSignout(t *testing.T) {
|
|||
|
||||
var sessionCount int
|
||||
var postSession1, postSession2 database.Session
|
||||
testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
|
||||
testutils.MustExec(t, db.Where("key = ?", "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=").First(&postSession1), "getting postSession1")
|
||||
testutils.MustExec(t, db.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&postSession2), "getting postSession2")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
|
||||
testutils.MustExec(t, testutils.DB.Where("key = ?", "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=").First(&postSession1), "getting postSession1")
|
||||
testutils.MustExec(t, testutils.DB.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&postSession2), "getting postSession2")
|
||||
|
||||
// two existing sessions should remain
|
||||
assert.Equal(t, sessionCount, 2, "sessionCount mismatch")
|
||||
|
|
|
|||
|
|
@ -24,11 +24,12 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/helpers"
|
||||
"github.com/dnote/dnote/pkg/server/operations"
|
||||
"github.com/dnote/dnote/pkg/server/presenters"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
|
@ -69,10 +70,8 @@ func (a *App) CreateBook(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
db := database.DBConn
|
||||
|
||||
var bookCount int
|
||||
err = db.Model(database.Book{}).
|
||||
err = a.DB.Model(database.Book{}).
|
||||
Where("user_id = ? AND label = ?", user.ID, params.Name).
|
||||
Count(&bookCount).Error
|
||||
if err != nil {
|
||||
|
|
@ -84,7 +83,7 @@ func (a *App) CreateBook(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
book, err := operations.CreateBook(user, a.Clock, params.Name)
|
||||
book, err := operations.CreateBook(a.DB, user, a.Clock, params.Name)
|
||||
if err != nil {
|
||||
HandleError(w, "inserting book", err, http.StatusInternalServerError)
|
||||
}
|
||||
|
|
@ -100,9 +99,7 @@ func (a *App) BooksOptions(w http.ResponseWriter, r *http.Request) {
|
|||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Version")
|
||||
}
|
||||
|
||||
func respondWithBooks(userID int, query url.Values, w http.ResponseWriter) {
|
||||
db := database.DBConn
|
||||
|
||||
func respondWithBooks(db *gorm.DB, userID int, query url.Values, w http.ResponseWriter) {
|
||||
var books []database.Book
|
||||
conn := db.Where("user_id = ? AND NOT deleted", userID).Order("label ASC")
|
||||
name := query.Get("name")
|
||||
|
|
@ -132,19 +129,6 @@ func respondWithBooks(userID int, query url.Values, w http.ResponseWriter) {
|
|||
respondJSON(w, http.StatusOK, presentedBooks)
|
||||
}
|
||||
|
||||
// GetDemoBooks returns books for demo
|
||||
func (a *App) GetDemoBooks(w http.ResponseWriter, r *http.Request) {
|
||||
demoUserID, err := helpers.GetDemoUserID()
|
||||
if err != nil {
|
||||
HandleError(w, "finding demo user", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
query := r.URL.Query()
|
||||
|
||||
respondWithBooks(demoUserID, query, w)
|
||||
}
|
||||
|
||||
// GetBooks returns books for the user
|
||||
func (a *App) GetBooks(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
|
||||
|
|
@ -154,7 +138,7 @@ func (a *App) GetBooks(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
query := r.URL.Query()
|
||||
|
||||
respondWithBooks(user.ID, query, w)
|
||||
respondWithBooks(a.DB, user.ID, query, w)
|
||||
}
|
||||
|
||||
// GetBook returns a book for the user
|
||||
|
|
@ -164,13 +148,11 @@ func (a *App) GetBook(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
db := database.DBConn
|
||||
|
||||
vars := mux.Vars(r)
|
||||
bookUUID := vars["bookUUID"]
|
||||
|
||||
var book database.Book
|
||||
conn := db.Where("uuid = ? AND user_id = ?", bookUUID, user.ID).First(&book)
|
||||
conn := a.DB.Where("uuid = ? AND user_id = ?", bookUUID, user.ID).First(&book)
|
||||
|
||||
if conn.RecordNotFound() {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
|
|
@ -204,8 +186,7 @@ func (a *App) UpdateBook(w http.ResponseWriter, r *http.Request) {
|
|||
vars := mux.Vars(r)
|
||||
uuid := vars["bookUUID"]
|
||||
|
||||
db := database.DBConn
|
||||
tx := db.Begin()
|
||||
tx := a.DB.Begin()
|
||||
|
||||
var book database.Book
|
||||
if err := tx.Where("user_id = ? AND uuid = ?", user.ID, uuid).First(&book).Error; err != nil {
|
||||
|
|
@ -250,8 +231,7 @@ func (a *App) DeleteBook(w http.ResponseWriter, r *http.Request) {
|
|||
vars := mux.Vars(r)
|
||||
uuid := vars["bookUUID"]
|
||||
|
||||
db := database.DBConn
|
||||
tx := db.Begin()
|
||||
tx := a.DB.Begin()
|
||||
|
||||
var book database.Book
|
||||
if err := tx.Where("user_id = ? AND uuid = ?", user.ID, uuid).First(&book).Error; err != nil {
|
||||
|
|
|
|||
|
|
@ -26,23 +26,20 @@ import (
|
|||
|
||||
"github.com/dnote/dnote/pkg/assert"
|
||||
"github.com/dnote/dnote/pkg/clock"
|
||||
"github.com/dnote/dnote/pkg/server/presenters"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/presenters"
|
||||
"github.com/dnote/dnote/pkg/server/testutils"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func init() {
|
||||
testutils.InitTestDB()
|
||||
}
|
||||
|
||||
func TestGetBooks(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
Clock: clock.NewMock(),
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
|
|
@ -55,28 +52,28 @@ func TestGetBooks(t *testing.T) {
|
|||
USN: 1123,
|
||||
Deleted: false,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing b1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
|
||||
b2 := database.Book{
|
||||
UserID: user.ID,
|
||||
Label: "css",
|
||||
USN: 1125,
|
||||
Deleted: false,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b2), "preparing b2")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
|
||||
b3 := database.Book{
|
||||
UserID: anotherUser.ID,
|
||||
Label: "css",
|
||||
USN: 1128,
|
||||
Deleted: false,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b3), "preparing b3")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
|
||||
b4 := database.Book{
|
||||
UserID: user.ID,
|
||||
Label: "",
|
||||
USN: 1129,
|
||||
Deleted: true,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b4), "preparing b4")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b4), "preparing b4")
|
||||
|
||||
// Execute
|
||||
req := testutils.MakeReq(server, "GET", "/v3/books", "")
|
||||
|
|
@ -91,9 +88,9 @@ func TestGetBooks(t *testing.T) {
|
|||
}
|
||||
|
||||
var b1Record, b2Record database.Book
|
||||
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
|
||||
testutils.MustExec(t, db.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
|
||||
testutils.MustExec(t, db.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
|
||||
|
||||
expected := []presenters.Book{
|
||||
{
|
||||
|
|
@ -116,12 +113,13 @@ func TestGetBooks(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestGetBooksByName(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
Clock: clock.NewMock(),
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
|
|
@ -133,17 +131,17 @@ func TestGetBooksByName(t *testing.T) {
|
|||
UserID: user.ID,
|
||||
Label: "js",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing b1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
|
||||
b2 := database.Book{
|
||||
UserID: user.ID,
|
||||
Label: "css",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b2), "preparing b2")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
|
||||
b3 := database.Book{
|
||||
UserID: anotherUser.ID,
|
||||
Label: "js",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b3), "preparing b3")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
|
||||
|
||||
// Execute
|
||||
res := testutils.HTTPAuthDo(t, req, user)
|
||||
|
|
@ -157,7 +155,7 @@ func TestGetBooksByName(t *testing.T) {
|
|||
}
|
||||
|
||||
var b1Record database.Book
|
||||
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
|
||||
|
||||
expected := []presenters.Book{
|
||||
{
|
||||
|
|
@ -201,39 +199,40 @@ func TestDeleteBook(t *testing.T) {
|
|||
|
||||
for _, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("originally deleted %t", tc.deleted), func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
Clock: clock.NewMock(),
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", 58), "preparing user max_usn")
|
||||
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 58), "preparing user max_usn")
|
||||
anotherUser := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 109), "preparing another user max_usn")
|
||||
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 109), "preparing another user max_usn")
|
||||
|
||||
b1 := database.Book{
|
||||
UserID: user.ID,
|
||||
Label: "js",
|
||||
USN: 1,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing a book data")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing a book data")
|
||||
b2 := database.Book{
|
||||
UserID: user.ID,
|
||||
Label: tc.label,
|
||||
USN: 2,
|
||||
Deleted: tc.deleted,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b2), "preparing a book data")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing a book data")
|
||||
b3 := database.Book{
|
||||
UserID: anotherUser.ID,
|
||||
Label: "linux",
|
||||
USN: 3,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b3), "preparing a book data")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b3), "preparing a book data")
|
||||
|
||||
var n2Body string
|
||||
if !tc.deleted {
|
||||
|
|
@ -250,7 +249,7 @@ func TestDeleteBook(t *testing.T) {
|
|||
Body: "n1 content",
|
||||
USN: 4,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n1), "preparing a note data")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n1), "preparing a note data")
|
||||
n2 := database.Note{
|
||||
UserID: user.ID,
|
||||
BookUUID: b2.UUID,
|
||||
|
|
@ -258,7 +257,7 @@ func TestDeleteBook(t *testing.T) {
|
|||
USN: 5,
|
||||
Deleted: tc.deleted,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n2), "preparing a note data")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n2), "preparing a note data")
|
||||
n3 := database.Note{
|
||||
UserID: user.ID,
|
||||
BookUUID: b2.UUID,
|
||||
|
|
@ -266,7 +265,7 @@ func TestDeleteBook(t *testing.T) {
|
|||
USN: 6,
|
||||
Deleted: tc.deleted,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n3), "preparing a note data")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n3), "preparing a note data")
|
||||
n4 := database.Note{
|
||||
UserID: user.ID,
|
||||
BookUUID: b2.UUID,
|
||||
|
|
@ -274,14 +273,14 @@ func TestDeleteBook(t *testing.T) {
|
|||
USN: 7,
|
||||
Deleted: true,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n4), "preparing a note data")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n4), "preparing a note data")
|
||||
n5 := database.Note{
|
||||
UserID: anotherUser.ID,
|
||||
BookUUID: b3.UUID,
|
||||
Body: "n5 content",
|
||||
USN: 8,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n5), "preparing a note data")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n5), "preparing a note data")
|
||||
|
||||
endpoint := fmt.Sprintf("/v3/books/%s", b2.UUID)
|
||||
req := testutils.MakeReq(server, "DELETE", endpoint, "")
|
||||
|
|
@ -299,17 +298,17 @@ func TestDeleteBook(t *testing.T) {
|
|||
var userRecord database.User
|
||||
var bookCount, noteCount int
|
||||
|
||||
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
|
||||
testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
|
||||
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
|
||||
testutils.MustExec(t, db.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
|
||||
testutils.MustExec(t, db.Where("id = ?", b3.ID).First(&b3Record), "finding b3")
|
||||
testutils.MustExec(t, db.Where("id = ?", n1.ID).First(&n1Record), "finding n1")
|
||||
testutils.MustExec(t, db.Where("id = ?", n2.ID).First(&n2Record), "finding n2")
|
||||
testutils.MustExec(t, db.Where("id = ?", n3.ID).First(&n3Record), "finding n3")
|
||||
testutils.MustExec(t, db.Where("id = ?", n4.ID).First(&n4Record), "finding n4")
|
||||
testutils.MustExec(t, db.Where("id = ?", n5.ID).First(&n5Record), "finding n5")
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", b3.ID).First(&b3Record), "finding b3")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", n1.ID).First(&n1Record), "finding n1")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", n2.ID).First(&n2Record), "finding n2")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", n3.ID).First(&n3Record), "finding n3")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", n4.ID).First(&n4Record), "finding n4")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", n5.ID).First(&n5Record), "finding n5")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
|
||||
|
||||
assert.Equal(t, bookCount, 3, "book count mismatch")
|
||||
assert.Equal(t, noteCount, 5, "note count mismatch")
|
||||
|
|
@ -351,17 +350,18 @@ func TestDeleteBook(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCreateBook(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
Clock: clock.NewMock(),
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
|
||||
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
|
||||
|
||||
req := testutils.MakeReq(server, "POST", "/v3/books", `{"name": "js"}`)
|
||||
req.Header.Set("Version", "0.1.1")
|
||||
|
|
@ -376,10 +376,10 @@ func TestCreateBook(t *testing.T) {
|
|||
var bookRecord database.Book
|
||||
var userRecord database.User
|
||||
var bookCount, noteCount int
|
||||
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
|
||||
testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
|
||||
testutils.MustExec(t, db.First(&bookRecord), "finding book")
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
|
||||
testutils.MustExec(t, testutils.DB.First(&bookRecord), "finding book")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
|
||||
|
||||
maxUSN := 102
|
||||
|
||||
|
|
@ -410,24 +410,25 @@ func TestCreateBook(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCreateBookDuplicate(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
Clock: clock.NewMock(),
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
|
||||
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
|
||||
|
||||
b1 := database.Book{
|
||||
UserID: user.ID,
|
||||
Label: "js",
|
||||
USN: 58,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing book data")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing book data")
|
||||
|
||||
// Execute
|
||||
req := testutils.MakeReq(server, "POST", "/v3/books", `{"name": "js"}`)
|
||||
|
|
@ -439,10 +440,10 @@ func TestCreateBookDuplicate(t *testing.T) {
|
|||
var bookRecord database.Book
|
||||
var bookCount, noteCount int
|
||||
var userRecord database.User
|
||||
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
|
||||
testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
|
||||
testutils.MustExec(t, db.First(&bookRecord), "finding book")
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
|
||||
testutils.MustExec(t, testutils.DB.First(&bookRecord), "finding book")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
|
||||
|
||||
assert.Equalf(t, bookCount, 1, "book count mismatch")
|
||||
assert.Equalf(t, noteCount, 0, "note count mismatch")
|
||||
|
|
@ -489,17 +490,18 @@ func TestUpdateBook(t *testing.T) {
|
|||
|
||||
for idx, tc := range testCases {
|
||||
func() {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
Clock: clock.NewMock(),
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
|
||||
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
|
||||
|
||||
b1 := database.Book{
|
||||
UUID: tc.bookUUID,
|
||||
|
|
@ -507,15 +509,15 @@ func TestUpdateBook(t *testing.T) {
|
|||
Label: tc.bookLabel,
|
||||
Deleted: tc.bookDeleted,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing b1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
|
||||
b2 := database.Book{
|
||||
UUID: b2UUID,
|
||||
UserID: user.ID,
|
||||
Label: "js",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b2), "preparing b2")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
|
||||
|
||||
// Execute
|
||||
// Executdb,e
|
||||
endpoint := fmt.Sprintf("/v3/books/%s", tc.bookUUID)
|
||||
req := testutils.MakeReq(server, "PATCH", endpoint, tc.payload)
|
||||
res := testutils.HTTPAuthDo(t, req, user)
|
||||
|
|
@ -526,10 +528,10 @@ func TestUpdateBook(t *testing.T) {
|
|||
var bookRecord database.Book
|
||||
var userRecord database.User
|
||||
var noteCount, bookCount int
|
||||
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
|
||||
testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
|
||||
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
|
||||
|
||||
assert.Equalf(t, bookCount, 2, "book count mismatch")
|
||||
assert.Equalf(t, noteCount, 0, "note count mismatch")
|
||||
|
|
|
|||
|
|
@ -23,10 +23,10 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/helpers"
|
||||
"github.com/dnote/dnote/pkg/server/operations"
|
||||
"github.com/dnote/dnote/pkg/server/presenters"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
|
@ -48,7 +48,6 @@ func validateUpdateNotePayload(p updateNotePayload) bool {
|
|||
|
||||
// UpdateNote updates note
|
||||
func (a *App) UpdateNote(w http.ResponseWriter, r *http.Request) {
|
||||
db := database.DBConn
|
||||
vars := mux.Vars(r)
|
||||
noteUUID := vars["noteUUID"]
|
||||
|
||||
|
|
@ -71,12 +70,12 @@ func (a *App) UpdateNote(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var note database.Note
|
||||
if err := db.Where("uuid = ? AND user_id = ?", noteUUID, user.ID).First(¬e).Error; err != nil {
|
||||
if err := a.DB.Where("uuid = ? AND user_id = ?", noteUUID, user.ID).First(¬e).Error; err != nil {
|
||||
HandleError(w, "finding note", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
tx := db.Begin()
|
||||
tx := a.DB.Begin()
|
||||
|
||||
note, err = operations.UpdateNote(tx, user, a.Clock, note, &operations.UpdateNoteParams{
|
||||
BookUUID: params.BookUUID,
|
||||
|
|
@ -116,8 +115,6 @@ type deleteNoteResp struct {
|
|||
|
||||
// DeleteNote removes note
|
||||
func (a *App) DeleteNote(w http.ResponseWriter, r *http.Request) {
|
||||
db := database.DBConn
|
||||
|
||||
vars := mux.Vars(r)
|
||||
noteUUID := vars["noteUUID"]
|
||||
|
||||
|
|
@ -128,12 +125,12 @@ func (a *App) DeleteNote(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var note database.Note
|
||||
if err := db.Where("uuid = ? AND user_id = ?", noteUUID, user.ID).Preload("Book").First(¬e).Error; err != nil {
|
||||
if err := a.DB.Where("uuid = ? AND user_id = ?", noteUUID, user.ID).Preload("Book").First(¬e).Error; err != nil {
|
||||
HandleError(w, "finding note", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
tx := db.Begin()
|
||||
tx := a.DB.Begin()
|
||||
|
||||
n, err := operations.DeleteNote(tx, user, note)
|
||||
if err != nil {
|
||||
|
|
@ -193,13 +190,12 @@ func (a *App) CreateNote(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var book database.Book
|
||||
db := database.DBConn
|
||||
if err := db.Where("uuid = ? AND user_id = ?", params.BookUUID, user.ID).First(&book).Error; err != nil {
|
||||
if err := a.DB.Where("uuid = ? AND user_id = ?", params.BookUUID, user.ID).First(&book).Error; err != nil {
|
||||
HandleError(w, "finding book", err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
note, err := operations.CreateNote(user, a.Clock, params.BookUUID, params.Content, params.AddedOn, params.EditedOn, false)
|
||||
note, err := operations.CreateNote(a.DB, user, a.Clock, params.BookUUID, params.Content, params.AddedOn, params.EditedOn, false)
|
||||
if err != nil {
|
||||
HandleError(w, "creating note", err, http.StatusInternalServerError)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -29,29 +29,26 @@ import (
|
|||
"github.com/dnote/dnote/pkg/server/testutils"
|
||||
)
|
||||
|
||||
func init() {
|
||||
testutils.InitTestDB()
|
||||
}
|
||||
|
||||
func TestCreateNote(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
Clock: clock.NewMock(),
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
|
||||
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
|
||||
|
||||
b1 := database.Book{
|
||||
UserID: user.ID,
|
||||
Label: "js",
|
||||
USN: 58,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing b1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
|
||||
|
||||
// Execute
|
||||
dat := fmt.Sprintf(`{"book_uuid": "%s", "content": "note content"}`, b1.UUID)
|
||||
|
|
@ -65,11 +62,11 @@ func TestCreateNote(t *testing.T) {
|
|||
var bookRecord database.Book
|
||||
var userRecord database.User
|
||||
var bookCount, noteCount int
|
||||
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
|
||||
testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
|
||||
testutils.MustExec(t, db.First(¬eRecord), "finding note")
|
||||
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
|
||||
testutils.MustExec(t, testutils.DB.First(¬eRecord), "finding note")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
|
||||
|
||||
assert.Equalf(t, bookCount, 1, "book count mismatch")
|
||||
assert.Equalf(t, noteCount, 1, "note count mismatch")
|
||||
|
|
@ -238,30 +235,31 @@ func TestUpdateNote(t *testing.T) {
|
|||
|
||||
for idx, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
Clock: clock.NewMock(),
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
|
||||
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
|
||||
|
||||
b1 := database.Book{
|
||||
UUID: b1UUID,
|
||||
UserID: user.ID,
|
||||
Label: "css",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing b1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
|
||||
b2 := database.Book{
|
||||
UUID: b2UUID,
|
||||
UserID: user.ID,
|
||||
Label: "js",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b2), "preparing b2")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
|
||||
|
||||
note := database.Note{
|
||||
UserID: user.ID,
|
||||
|
|
@ -271,7 +269,7 @@ func TestUpdateNote(t *testing.T) {
|
|||
Deleted: tc.noteDeleted,
|
||||
Public: tc.notePublic,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(¬e), "preparing note")
|
||||
testutils.MustExec(t, testutils.DB.Save(¬e), "preparing note")
|
||||
|
||||
// Execute
|
||||
endpoint := fmt.Sprintf("/v3/notes/%s", note.UUID)
|
||||
|
|
@ -285,11 +283,11 @@ func TestUpdateNote(t *testing.T) {
|
|||
var noteRecord database.Note
|
||||
var userRecord database.User
|
||||
var noteCount, bookCount int
|
||||
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
|
||||
testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
|
||||
testutils.MustExec(t, db.Where("uuid = ?", note.UUID).First(¬eRecord), "finding note")
|
||||
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", note.UUID).First(¬eRecord), "finding note")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
|
||||
|
||||
assert.Equalf(t, bookCount, 2, "book count mismatch")
|
||||
assert.Equalf(t, noteCount, 1, "note count mismatch")
|
||||
|
|
@ -333,24 +331,25 @@ func TestDeleteNote(t *testing.T) {
|
|||
|
||||
for _, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("originally deleted %t", tc.deleted), func(t *testing.T) {
|
||||
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
// Setup
|
||||
server := MustNewServer(t, &App{
|
||||
Clock: clock.NewMock(),
|
||||
|
||||
Clock: clock.NewMock(),
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", 981), "preparing user max_usn")
|
||||
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 981), "preparing user max_usn")
|
||||
|
||||
b1 := database.Book{
|
||||
UUID: b1UUID,
|
||||
UserID: user.ID,
|
||||
Label: "js",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing b1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
|
||||
note := database.Note{
|
||||
UserID: user.ID,
|
||||
BookUUID: b1.UUID,
|
||||
|
|
@ -358,7 +357,7 @@ func TestDeleteNote(t *testing.T) {
|
|||
Deleted: tc.deleted,
|
||||
USN: tc.originalUSN,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(¬e), "preparing note")
|
||||
testutils.MustExec(t, testutils.DB.Save(¬e), "preparing note")
|
||||
|
||||
// Execute
|
||||
endpoint := fmt.Sprintf("/v3/notes/%s", note.UUID)
|
||||
|
|
@ -372,11 +371,11 @@ func TestDeleteNote(t *testing.T) {
|
|||
var noteRecord database.Note
|
||||
var userRecord database.User
|
||||
var bookCount, noteCount int
|
||||
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
|
||||
testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
|
||||
testutils.MustExec(t, db.Where("uuid = ?", note.UUID).First(¬eRecord), "finding note")
|
||||
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", note.UUID).First(¬eRecord), "finding note")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
|
||||
|
||||
assert.Equalf(t, bookCount, 1, "book count mismatch")
|
||||
assert.Equalf(t, noteCount, 1, "note count mismatch")
|
||||
|
|
|
|||
|
|
@ -26,8 +26,8 @@ import (
|
|||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/dnote/dnote/pkg/server/helpers"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/helpers"
|
||||
"github.com/dnote/dnote/pkg/server/log"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
|
@ -121,14 +121,12 @@ func (e *queryParamError) Error() string {
|
|||
}
|
||||
|
||||
func (a *App) newFragment(userID, userMaxUSN, afterUSN, limit int) (SyncFragment, error) {
|
||||
db := database.DBConn
|
||||
|
||||
var notes []database.Note
|
||||
if err := db.Where("user_id = ? AND usn > ? AND usn <= ?", userID, afterUSN, userMaxUSN).Order("usn ASC").Limit(limit).Find(¬es).Error; err != nil {
|
||||
if err := a.DB.Where("user_id = ? AND usn > ? AND usn <= ?", userID, afterUSN, userMaxUSN).Order("usn ASC").Limit(limit).Find(¬es).Error; err != nil {
|
||||
return SyncFragment{}, nil
|
||||
}
|
||||
var books []database.Book
|
||||
if err := db.Where("user_id = ? AND usn > ? AND usn <= ?", userID, afterUSN, userMaxUSN).Order("usn ASC").Limit(limit).Find(&books).Error; err != nil {
|
||||
if err := a.DB.Where("user_id = ? AND usn > ? AND usn <= ?", userID, afterUSN, userMaxUSN).Order("usn ASC").Limit(limit).Find(&books).Error; err != nil {
|
||||
return SyncFragment{}, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -19,8 +19,8 @@
|
|||
package helpers
|
||||
|
||||
import (
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
|
@ -29,8 +29,7 @@ const (
|
|||
)
|
||||
|
||||
// GetDemoUserID returns ID of the demo user
|
||||
func GetDemoUserID() (int, error) {
|
||||
db := database.DBConn
|
||||
func GetDemoUserID(db *gorm.DB) (int, error) {
|
||||
|
||||
result := struct {
|
||||
UserID int
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import (
|
|||
|
||||
"github.com/dnote/dnote/pkg/clock"
|
||||
"github.com/dnote/dnote/pkg/server/job/repetition"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/robfig/cron"
|
||||
)
|
||||
|
|
@ -45,12 +46,12 @@ func checkEnvironment() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func schedule(ch chan error) {
|
||||
func schedule(db *gorm.DB, ch chan error) {
|
||||
cl := clock.New()
|
||||
|
||||
// Schedule jobs
|
||||
c := cron.New()
|
||||
scheduleJob(c, "* * * * *", func() { repetition.Do(cl) })
|
||||
scheduleJob(c, "* * * * *", func() { repetition.Do(db, cl) })
|
||||
c.Start()
|
||||
|
||||
ch <- nil
|
||||
|
|
@ -60,13 +61,13 @@ func schedule(ch chan error) {
|
|||
}
|
||||
|
||||
// Run starts the background tasks in a separate goroutine that runs forever
|
||||
func Run() error {
|
||||
func Run(db *gorm.DB) error {
|
||||
if err := checkEnvironment(); err != nil {
|
||||
return errors.Wrap(err, "checking environment variables")
|
||||
}
|
||||
|
||||
ch := make(chan error)
|
||||
go schedule(ch)
|
||||
go schedule(db, ch)
|
||||
if err := <-ch; err != nil {
|
||||
return errors.Wrap(err, "scheduling jobs")
|
||||
}
|
||||
|
|
|
|||
17
pkg/server/job/repetition/main_test.go
Normal file
17
pkg/server/job/repetition/main_test.go
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
package repetition
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/dnote/dnote/pkg/server/testutils"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutils.InitTestDB()
|
||||
|
||||
code := m.Run()
|
||||
testutils.ClearData()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
|
@ -31,20 +31,29 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// BuildEmailParams is the params for building an email
|
||||
type BuildEmailParams struct {
|
||||
Now time.Time
|
||||
User database.User
|
||||
EmailAddr string
|
||||
Digest database.Digest
|
||||
Rule database.RepetitionRule
|
||||
}
|
||||
|
||||
// BuildEmail builds an email for the spaced repetition
|
||||
func BuildEmail(now time.Time, user database.User, emailAddr string, digest database.Digest, rule database.RepetitionRule) (*mailer.Email, error) {
|
||||
date := now.Format("Jan 02 2006")
|
||||
subject := fmt.Sprintf("%s %s", rule.Title, date)
|
||||
tok, err := mailer.GetToken(user, database.TokenTypeRepetition)
|
||||
func BuildEmail(db *gorm.DB, p BuildEmailParams) (*mailer.Email, error) {
|
||||
date := p.Now.Format("Jan 02 2006")
|
||||
subject := fmt.Sprintf("%s %s", p.Rule.Title, date)
|
||||
tok, err := mailer.GetToken(db, p.User, database.TokenTypeRepetition)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getting email frequency token")
|
||||
}
|
||||
|
||||
t1 := now.AddDate(0, 0, -3).UnixNano()
|
||||
t2 := now.AddDate(0, 0, -7).UnixNano()
|
||||
t1 := p.Now.AddDate(0, 0, -3).UnixNano()
|
||||
t2 := p.Now.AddDate(0, 0, -7).UnixNano()
|
||||
|
||||
noteInfos := []mailer.DigestNoteInfo{}
|
||||
for _, note := range digest.Notes {
|
||||
for _, note := range p.Digest.Notes {
|
||||
var stage int
|
||||
if note.AddedOn > t1 {
|
||||
stage = 1
|
||||
|
|
@ -60,7 +69,7 @@ func BuildEmail(now time.Time, user database.User, emailAddr string, digest data
|
|||
|
||||
bookCount := 0
|
||||
bookMap := map[string]bool{}
|
||||
for _, n := range digest.Notes {
|
||||
for _, n := range p.Digest.Notes {
|
||||
if ok := bookMap[n.Book.Label]; !ok {
|
||||
bookCount++
|
||||
bookMap[n.Book.Label] = true
|
||||
|
|
@ -71,14 +80,14 @@ func BuildEmail(now time.Time, user database.User, emailAddr string, digest data
|
|||
Subject: subject,
|
||||
NoteInfo: noteInfos,
|
||||
ActiveBookCount: bookCount,
|
||||
ActiveNoteCount: len(digest.Notes),
|
||||
ActiveNoteCount: len(p.Digest.Notes),
|
||||
EmailSessionToken: tok.Value,
|
||||
RuleUUID: rule.UUID,
|
||||
RuleTitle: rule.Title,
|
||||
RuleUUID: p.Rule.UUID,
|
||||
RuleTitle: p.Rule.Title,
|
||||
WebURL: os.Getenv("WebURL"),
|
||||
}
|
||||
|
||||
email := mailer.NewEmail("noreply@getdnote.com", []string{emailAddr}, subject)
|
||||
email := mailer.NewEmail("noreply@getdnote.com", []string{p.EmailAddr}, subject)
|
||||
if err := email.ParseTemplate(mailer.EmailTypeWeeklyDigest, tmplData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -86,12 +95,11 @@ func BuildEmail(now time.Time, user database.User, emailAddr string, digest data
|
|||
return email, nil
|
||||
}
|
||||
|
||||
func getEligibleRules(now time.Time) ([]database.RepetitionRule, error) {
|
||||
func getEligibleRules(db *gorm.DB, now time.Time) ([]database.RepetitionRule, error) {
|
||||
hour := now.Hour()
|
||||
minute := now.Minute()
|
||||
|
||||
var ret []database.RepetitionRule
|
||||
db := database.DBConn
|
||||
if err := db.
|
||||
Where("users.cloud AND repetition_rules.hour = ? AND repetition_rules.minute = ? AND repetition_rules.enabled", hour, minute).
|
||||
Joins("INNER JOIN users ON users.id = repetition_rules.user_id").
|
||||
|
|
@ -120,9 +128,7 @@ func build(tx *gorm.DB, rule database.RepetitionRule) (database.Digest, error) {
|
|||
return digest, nil
|
||||
}
|
||||
|
||||
func notify(now time.Time, user database.User, digest database.Digest, rule database.RepetitionRule) error {
|
||||
db := database.DBConn
|
||||
|
||||
func notify(db *gorm.DB, now time.Time, user database.User, digest database.Digest, rule database.RepetitionRule) error {
|
||||
var account database.Account
|
||||
if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
|
||||
return errors.Wrap(err, "getting account")
|
||||
|
|
@ -135,7 +141,13 @@ func notify(now time.Time, user database.User, digest database.Digest, rule data
|
|||
return nil
|
||||
}
|
||||
|
||||
email, err := BuildEmail(now, user, account.Email.String, digest, rule)
|
||||
email, err := BuildEmail(db, BuildEmailParams{
|
||||
Now: now,
|
||||
User: user,
|
||||
EmailAddr: account.Email.String,
|
||||
Digest: digest,
|
||||
Rule: rule,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "making email")
|
||||
}
|
||||
|
|
@ -185,12 +197,11 @@ func touchTimestamp(tx *gorm.DB, rule database.RepetitionRule, now time.Time) er
|
|||
return nil
|
||||
}
|
||||
|
||||
func process(now time.Time, rule database.RepetitionRule) error {
|
||||
func process(db *gorm.DB, now time.Time, rule database.RepetitionRule) error {
|
||||
log.WithFields(log.Fields{
|
||||
"uuid": rule.UUID,
|
||||
}).Info("processing repetition")
|
||||
|
||||
db := database.DBConn
|
||||
tx := db.Begin()
|
||||
|
||||
if !checkCooldown(now, rule) {
|
||||
|
|
@ -224,7 +235,7 @@ func process(now time.Time, rule database.RepetitionRule) error {
|
|||
return errors.Wrap(err, "committing transaction")
|
||||
}
|
||||
|
||||
if err := notify(now, user, digest, rule); err != nil {
|
||||
if err := notify(db, now, user, digest, rule); err != nil {
|
||||
return errors.Wrap(err, "notifying user")
|
||||
}
|
||||
|
||||
|
|
@ -236,10 +247,10 @@ func process(now time.Time, rule database.RepetitionRule) error {
|
|||
}
|
||||
|
||||
// Do creates spaced repetitions and delivers the results based on the rules
|
||||
func Do(c clock.Clock) error {
|
||||
func Do(db *gorm.DB, c clock.Clock) error {
|
||||
now := c.Now().UTC()
|
||||
|
||||
rules, err := getEligibleRules(now)
|
||||
rules, err := getEligibleRules(db, now)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "getting eligible repetition rules")
|
||||
}
|
||||
|
|
@ -251,7 +262,7 @@ func Do(c clock.Clock) error {
|
|||
}).Info("processing rules")
|
||||
|
||||
for _, rule := range rules {
|
||||
if err := process(now, rule); err != nil {
|
||||
if err := process(db, now, rule); err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"rule uuid": rule.UUID,
|
||||
}).ErrorWrap(err, "Could not process the repetition rule")
|
||||
|
|
|
|||
|
|
@ -29,24 +29,18 @@ import (
|
|||
"github.com/dnote/dnote/pkg/server/testutils"
|
||||
)
|
||||
|
||||
func init() {
|
||||
testutils.InitTestDB()
|
||||
}
|
||||
|
||||
func assertLastActive(t *testing.T, ruleUUID string, lastActive int64) {
|
||||
db := database.DBConn
|
||||
|
||||
var rule database.RepetitionRule
|
||||
testutils.MustExec(t, db.Where("uuid = ?", ruleUUID).First(&rule), "finding rule1")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", ruleUUID).First(&rule), "finding rule1")
|
||||
|
||||
assert.Equal(t, rule.LastActive, lastActive, "LastActive mismatch")
|
||||
}
|
||||
|
||||
func assertDigestCount(t *testing.T, rule database.RepetitionRule, expected int) {
|
||||
db := database.DBConn
|
||||
|
||||
var digestCount int
|
||||
testutils.MustExec(t, db.Model(&database.Digest{}).Where("rule_id = ? AND user_id = ?", rule.ID, rule.UserID).Count(&digestCount), "counting digest")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Digest{}).Where("rule_id = ? AND user_id = ?", rule.ID, rule.UserID).Count(&digestCount), "counting digest")
|
||||
assert.Equal(t, digestCount, expected, "digest count mismatch")
|
||||
}
|
||||
|
||||
|
|
@ -74,68 +68,67 @@ func TestDo(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
db := database.DBConn
|
||||
testutils.MustExec(t, db.Save(&r1), "preparing rule1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
|
||||
|
||||
c := clock.NewMock()
|
||||
|
||||
// Test
|
||||
// 1 day later
|
||||
c.SetNow(time.Date(2009, time.November, 2, 12, 2, 1, 0, time.UTC))
|
||||
Do(c)
|
||||
Do(testutils.DB, c)
|
||||
assertLastActive(t, r1.UUID, int64(0))
|
||||
assertDigestCount(t, r1, 0)
|
||||
|
||||
// 2 days later
|
||||
c.SetNow(time.Date(2009, time.November, 3, 12, 2, 1, 0, time.UTC))
|
||||
Do(c)
|
||||
Do(testutils.DB, c)
|
||||
assertLastActive(t, r1.UUID, int64(0))
|
||||
assertDigestCount(t, r1, 0)
|
||||
|
||||
// 3 days later - should be processed
|
||||
c.SetNow(time.Date(2009, time.November, 4, 12, 1, 1, 0, time.UTC))
|
||||
Do(c)
|
||||
Do(testutils.DB, c)
|
||||
assertLastActive(t, r1.UUID, int64(0))
|
||||
assertDigestCount(t, r1, 0)
|
||||
|
||||
c.SetNow(time.Date(2009, time.November, 4, 12, 2, 1, 0, time.UTC))
|
||||
Do(c)
|
||||
Do(testutils.DB, c)
|
||||
assertLastActive(t, r1.UUID, int64(1257336120000))
|
||||
assertDigestCount(t, r1, 1)
|
||||
|
||||
c.SetNow(time.Date(2009, time.November, 4, 12, 3, 1, 0, time.UTC))
|
||||
Do(c)
|
||||
Do(testutils.DB, c)
|
||||
assertLastActive(t, r1.UUID, int64(1257336120000))
|
||||
assertDigestCount(t, r1, 1)
|
||||
|
||||
// 4 day later
|
||||
c.SetNow(time.Date(2009, time.November, 5, 12, 2, 1, 0, time.UTC))
|
||||
Do(c)
|
||||
Do(testutils.DB, c)
|
||||
assertLastActive(t, r1.UUID, int64(1257336120000))
|
||||
assertDigestCount(t, r1, 1)
|
||||
// 5 days later
|
||||
c.SetNow(time.Date(2009, time.November, 6, 12, 2, 1, 0, time.UTC))
|
||||
Do(c)
|
||||
Do(testutils.DB, c)
|
||||
assertLastActive(t, r1.UUID, int64(1257336120000))
|
||||
assertDigestCount(t, r1, 1)
|
||||
// 6 days later - should be processed
|
||||
c.SetNow(time.Date(2009, time.November, 7, 12, 2, 1, 0, time.UTC))
|
||||
Do(c)
|
||||
Do(testutils.DB, c)
|
||||
assertLastActive(t, r1.UUID, int64(1257595320000))
|
||||
assertDigestCount(t, r1, 2)
|
||||
// 7 days later
|
||||
c.SetNow(time.Date(2009, time.November, 8, 12, 2, 1, 0, time.UTC))
|
||||
Do(c)
|
||||
Do(testutils.DB, c)
|
||||
assertLastActive(t, r1.UUID, int64(1257595320000))
|
||||
assertDigestCount(t, r1, 2)
|
||||
// 8 days later
|
||||
c.SetNow(time.Date(2009, time.November, 9, 12, 2, 1, 0, time.UTC))
|
||||
Do(c)
|
||||
Do(testutils.DB, c)
|
||||
assertLastActive(t, r1.UUID, int64(1257595320000))
|
||||
assertDigestCount(t, r1, 2)
|
||||
// 9 days later - should be processed
|
||||
c.SetNow(time.Date(2009, time.November, 10, 12, 2, 1, 0, time.UTC))
|
||||
Do(c)
|
||||
Do(testutils.DB, c)
|
||||
assertLastActive(t, r1.UUID, int64(1257854520000))
|
||||
assertDigestCount(t, r1, 3)
|
||||
})
|
||||
|
|
@ -177,15 +170,14 @@ func TestDo(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
db := database.DBConn
|
||||
testutils.MustExec(t, db.Save(&r1), "preparing rule1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
|
||||
|
||||
c := clock.NewMock()
|
||||
c.SetNow(time.Date(2009, time.November, 10, 12, 2, 1, 0, time.UTC))
|
||||
Do(c)
|
||||
Do(testutils.DB, c)
|
||||
|
||||
var rule database.RepetitionRule
|
||||
testutils.MustExec(t, db.Where("uuid = ?", r1.UUID).First(&rule), "finding rule1")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", r1.UUID).First(&rule), "finding rule1")
|
||||
|
||||
assert.Equal(t, rule.LastActive, time.Date(2009, time.November, 10, 12, 2, 0, 0, time.UTC).UnixNano()/int64(time.Millisecond), "LastActive mismsatch")
|
||||
assert.Equal(t, rule.NextActive, time.Date(2009, time.November, 13, 12, 2, 0, 0, time.UTC).UnixNano()/int64(time.Millisecond), "NextActive mismsatch")
|
||||
|
|
@ -216,13 +208,12 @@ func TestDo_Disabled(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
db := database.DBConn
|
||||
testutils.MustExec(t, db.Save(&r1), "preparing rule1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
|
||||
|
||||
// Execute
|
||||
c := clock.NewMock()
|
||||
c.SetNow(time.Date(2009, time.November, 4, 12, 2, 0, 0, time.UTC))
|
||||
Do(c)
|
||||
Do(testutils.DB, c)
|
||||
|
||||
// Test
|
||||
assertLastActive(t, r1.UUID, int64(0))
|
||||
|
|
@ -241,40 +232,39 @@ func TestDo_BalancedStrategy(t *testing.T) {
|
|||
}
|
||||
|
||||
setup := func() testData {
|
||||
db := database.DBConn
|
||||
user := testutils.SetupUserData()
|
||||
|
||||
b1 := database.Book{
|
||||
UserID: user.ID,
|
||||
Label: "js",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing b1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
|
||||
b2 := database.Book{
|
||||
UserID: user.ID,
|
||||
Label: "css",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b2), "preparing b2")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
|
||||
b3 := database.Book{
|
||||
UserID: user.ID,
|
||||
Label: "golang",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b3), "preparing b3")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
|
||||
|
||||
n1 := database.Note{
|
||||
UserID: user.ID,
|
||||
BookUUID: b1.UUID,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n1), "preparing n1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n1), "preparing n1")
|
||||
n2 := database.Note{
|
||||
UserID: user.ID,
|
||||
BookUUID: b2.UUID,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n2), "preparing n2")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n2), "preparing n2")
|
||||
n3 := database.Note{
|
||||
UserID: user.ID,
|
||||
BookUUID: b3.UUID,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n3), "preparing n3")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n3), "preparing n3")
|
||||
|
||||
return testData{
|
||||
User: user,
|
||||
|
|
@ -293,7 +283,6 @@ func TestDo_BalancedStrategy(t *testing.T) {
|
|||
// Set up
|
||||
dat := setup()
|
||||
|
||||
db := database.DBConn
|
||||
t0 := time.Date(2009, time.November, 1, 12, 0, 0, 0, time.UTC)
|
||||
t1 := time.Date(2009, time.November, 8, 21, 0, 0, 0, time.UTC)
|
||||
r1 := database.RepetitionRule{
|
||||
|
|
@ -312,20 +301,20 @@ func TestDo_BalancedStrategy(t *testing.T) {
|
|||
UpdatedAt: t0,
|
||||
},
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&r1), "preparing rule1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
|
||||
|
||||
// Execute
|
||||
c := clock.NewMock()
|
||||
|
||||
c.SetNow(time.Date(2009, time.November, 8, 21, 0, 0, 0, time.UTC))
|
||||
Do(c)
|
||||
Do(testutils.DB, c)
|
||||
|
||||
// Test
|
||||
assertLastActive(t, r1.UUID, int64(1257714000000))
|
||||
assertDigestCount(t, r1, 1)
|
||||
|
||||
var repetition database.Digest
|
||||
testutils.MustExec(t, db.Where("rule_id = ? AND user_id = ?", r1.ID, r1.UserID).Preload("Notes").First(&repetition), "finding repetition")
|
||||
testutils.MustExec(t, testutils.DB.Where("rule_id = ? AND user_id = ?", r1.ID, r1.UserID).Preload("Notes").First(&repetition), "finding repetition")
|
||||
|
||||
sort.SliceStable(repetition.Notes, func(i, j int) bool {
|
||||
n1 := repetition.Notes[i]
|
||||
|
|
@ -335,9 +324,9 @@ func TestDo_BalancedStrategy(t *testing.T) {
|
|||
})
|
||||
|
||||
var n1Record, n2Record, n3Record database.Note
|
||||
testutils.MustExec(t, db.Where("uuid = ?", dat.Note1.UUID).First(&n1Record), "finding n1")
|
||||
testutils.MustExec(t, db.Where("uuid = ?", dat.Note2.UUID).First(&n2Record), "finding n2")
|
||||
testutils.MustExec(t, db.Where("uuid = ?", dat.Note3.UUID).First(&n3Record), "finding n3")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note1.UUID).First(&n1Record), "finding n1")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note2.UUID).First(&n2Record), "finding n2")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note3.UUID).First(&n3Record), "finding n3")
|
||||
expected := []database.Note{n1Record, n2Record, n3Record}
|
||||
assert.DeepEqual(t, repetition.Notes, expected, "result mismatch")
|
||||
})
|
||||
|
|
@ -348,7 +337,6 @@ func TestDo_BalancedStrategy(t *testing.T) {
|
|||
// Set up
|
||||
dat := setup()
|
||||
|
||||
db := database.DBConn
|
||||
t0 := time.Date(2009, time.November, 1, 12, 0, 0, 0, time.UTC)
|
||||
t1 := time.Date(2009, time.November, 8, 21, 0, 0, 0, time.UTC)
|
||||
r1 := database.RepetitionRule{
|
||||
|
|
@ -368,20 +356,20 @@ func TestDo_BalancedStrategy(t *testing.T) {
|
|||
UpdatedAt: t0,
|
||||
},
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&r1), "preparing rule1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
|
||||
|
||||
// Execute
|
||||
c := clock.NewMock()
|
||||
|
||||
c.SetNow(time.Date(2009, time.November, 8, 21, 0, 1, 0, time.UTC))
|
||||
Do(c)
|
||||
Do(testutils.DB, c)
|
||||
|
||||
// Test
|
||||
assertLastActive(t, r1.UUID, int64(1257714000000))
|
||||
assertDigestCount(t, r1, 1)
|
||||
|
||||
var repetition database.Digest
|
||||
testutils.MustExec(t, db.Where("rule_id = ? AND user_id = ?", r1.ID, r1.UserID).Preload("Notes").First(&repetition), "finding repetition")
|
||||
testutils.MustExec(t, testutils.DB.Where("rule_id = ? AND user_id = ?", r1.ID, r1.UserID).Preload("Notes").First(&repetition), "finding repetition")
|
||||
|
||||
sort.SliceStable(repetition.Notes, func(i, j int) bool {
|
||||
n1 := repetition.Notes[i]
|
||||
|
|
@ -391,8 +379,8 @@ func TestDo_BalancedStrategy(t *testing.T) {
|
|||
})
|
||||
|
||||
var n2Record, n3Record database.Note
|
||||
testutils.MustExec(t, db.Where("uuid = ?", dat.Note2.UUID).First(&n2Record), "finding n2")
|
||||
testutils.MustExec(t, db.Where("uuid = ?", dat.Note3.UUID).First(&n3Record), "finding n3")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note2.UUID).First(&n2Record), "finding n2")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note3.UUID).First(&n3Record), "finding n3")
|
||||
expected := []database.Note{n2Record, n3Record}
|
||||
assert.DeepEqual(t, repetition.Notes, expected, "result mismatch")
|
||||
})
|
||||
|
|
@ -403,7 +391,6 @@ func TestDo_BalancedStrategy(t *testing.T) {
|
|||
// Set up
|
||||
dat := setup()
|
||||
|
||||
db := database.DBConn
|
||||
t0 := time.Date(2009, time.November, 1, 12, 0, 0, 0, time.UTC)
|
||||
t1 := time.Date(2009, time.November, 8, 21, 0, 0, 0, time.UTC)
|
||||
r1 := database.RepetitionRule{
|
||||
|
|
@ -423,20 +410,20 @@ func TestDo_BalancedStrategy(t *testing.T) {
|
|||
UpdatedAt: t0,
|
||||
},
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&r1), "preparing rule1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&r1), "preparing rule1")
|
||||
|
||||
// Execute
|
||||
c := clock.NewMock()
|
||||
|
||||
c.SetNow(time.Date(2009, time.November, 8, 21, 0, 0, 0, time.UTC))
|
||||
Do(c)
|
||||
Do(testutils.DB, c)
|
||||
|
||||
// Test
|
||||
assertLastActive(t, r1.UUID, int64(1257714000000))
|
||||
assertDigestCount(t, r1, 1)
|
||||
|
||||
var repetition database.Digest
|
||||
testutils.MustExec(t, db.Where("rule_id = ? AND user_id = ?", r1.ID, r1.UserID).Preload("Notes").First(&repetition), "finding repetition")
|
||||
testutils.MustExec(t, testutils.DB.Where("rule_id = ? AND user_id = ?", r1.ID, r1.UserID).Preload("Notes").First(&repetition), "finding repetition")
|
||||
|
||||
sort.SliceStable(repetition.Notes, func(i, j int) bool {
|
||||
n1 := repetition.Notes[i]
|
||||
|
|
@ -446,8 +433,8 @@ func TestDo_BalancedStrategy(t *testing.T) {
|
|||
})
|
||||
|
||||
var n1Record, n2Record database.Note
|
||||
testutils.MustExec(t, db.Where("uuid = ?", dat.Note1.UUID).First(&n1Record), "finding n1")
|
||||
testutils.MustExec(t, db.Where("uuid = ?", dat.Note2.UUID).First(&n2Record), "finding n2")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note1.UUID).First(&n1Record), "finding n1")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", dat.Note2.UUID).First(&n2Record), "finding n2")
|
||||
expected := []database.Note{n1Record, n2Record}
|
||||
assert.DeepEqual(t, repetition.Notes, expected, "result mismatch")
|
||||
})
|
||||
|
|
|
|||
|
|
@ -27,8 +27,7 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func getRuleBookIDs(ruleID int) ([]int, error) {
|
||||
db := database.DBConn
|
||||
func getRuleBookIDs(db *gorm.DB, ruleID int) ([]int, error) {
|
||||
var ret []int
|
||||
if err := db.Table("repetition_rule_books").Select("book_id").Where("repetition_rule_id = ?", ruleID).Pluck("book_id", &ret).Error; err != nil {
|
||||
return nil, errors.Wrap(err, "querying book_ids")
|
||||
|
|
@ -37,11 +36,11 @@ func getRuleBookIDs(ruleID int) ([]int, error) {
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
func applyBookDomain(noteQuery *gorm.DB, rule database.RepetitionRule) (*gorm.DB, error) {
|
||||
func applyBookDomain(db *gorm.DB, noteQuery *gorm.DB, rule database.RepetitionRule) (*gorm.DB, error) {
|
||||
ret := noteQuery
|
||||
|
||||
if rule.BookDomain != database.BookDomainAll {
|
||||
bookIDs, err := getRuleBookIDs(rule.ID)
|
||||
bookIDs, err := getRuleBookIDs(db, rule.ID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getting book_ids")
|
||||
}
|
||||
|
|
@ -58,8 +57,8 @@ func applyBookDomain(noteQuery *gorm.DB, rule database.RepetitionRule) (*gorm.DB
|
|||
return ret, nil
|
||||
}
|
||||
|
||||
func getNotes(conn *gorm.DB, rule database.RepetitionRule, dst *[]database.Note) error {
|
||||
c, err := applyBookDomain(conn, rule)
|
||||
func getNotes(db, conn *gorm.DB, rule database.RepetitionRule, dst *[]database.Note) error {
|
||||
c, err := applyBookDomain(db, conn, rule)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "building query for book threahold 1")
|
||||
}
|
||||
|
|
@ -79,16 +78,14 @@ func getBalancedNotes(db *gorm.DB, rule database.RepetitionRule) ([]database.Not
|
|||
t2 := now.AddDate(0, 0, -7).UnixNano()
|
||||
|
||||
// Get notes into three buckets with different threshold values
|
||||
var stage1 []database.Note
|
||||
var stage2 []database.Note
|
||||
var stage3 []database.Note
|
||||
if err := getNotes(db.Where("notes.added_on > ?", t1), rule, &stage1); err != nil {
|
||||
var stage1, stage2, stage3 []database.Note
|
||||
if err := getNotes(db, db.Where("notes.added_on > ?", t1), rule, &stage1); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to get notes with threshold 1")
|
||||
}
|
||||
if err := getNotes(db.Where("notes.added_on > ? AND notes.added_on < ?", t2, t1), rule, &stage2); err != nil {
|
||||
if err := getNotes(db, db.Where("notes.added_on > ? AND notes.added_on < ?", t2, t1), rule, &stage2); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to get notes with threshold 2")
|
||||
}
|
||||
if err := getNotes(db.Where("notes.added_on < ?", t2), rule, &stage3); err != nil {
|
||||
if err := getNotes(db, db.Where("notes.added_on < ?", t2), rule, &stage3); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to get notes with threshold 3")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -34,45 +34,43 @@ func init() {
|
|||
func TestApplyBookDomain(t *testing.T) {
|
||||
defer testutils.ClearData()
|
||||
|
||||
db := database.DBConn
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
b1 := database.Book{
|
||||
UserID: user.ID,
|
||||
Label: "js",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing b1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
|
||||
b2 := database.Book{
|
||||
UserID: user.ID,
|
||||
Label: "css",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b2), "preparing b2")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
|
||||
b3 := database.Book{
|
||||
UserID: user.ID,
|
||||
Label: "golang",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b3), "preparing b3")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
|
||||
|
||||
n1 := database.Note{
|
||||
UserID: user.ID,
|
||||
BookUUID: b1.UUID,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n1), "preparing n1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n1), "preparing n1")
|
||||
n2 := database.Note{
|
||||
UserID: user.ID,
|
||||
BookUUID: b2.UUID,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n2), "preparing n2")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n2), "preparing n2")
|
||||
n3 := database.Note{
|
||||
UserID: user.ID,
|
||||
BookUUID: b3.UUID,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n3), "preparing n3")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n3), "preparing n3")
|
||||
|
||||
var n1Record, n2Record, n3Record database.Note
|
||||
testutils.MustExec(t, db.Where("uuid = ?", n1.UUID).First(&n1Record), "finding n1")
|
||||
testutils.MustExec(t, db.Where("uuid = ?", n2.UUID).First(&n2Record), "finding n2")
|
||||
testutils.MustExec(t, db.Where("uuid = ?", n3.UUID).First(&n3Record), "finding n3")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", n1.UUID).First(&n1Record), "finding n1")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", n2.UUID).First(&n2Record), "finding n2")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", n3.UUID).First(&n3Record), "finding n3")
|
||||
|
||||
t.Run("book domain all", func(t *testing.T) {
|
||||
rule := database.RepetitionRule{
|
||||
|
|
@ -80,7 +78,7 @@ func TestApplyBookDomain(t *testing.T) {
|
|||
BookDomain: database.BookDomainAll,
|
||||
}
|
||||
|
||||
conn, err := applyBookDomain(db, rule)
|
||||
conn, err := applyBookDomain(testutils.DB, testutils.DB, rule)
|
||||
if err != nil {
|
||||
t.Fatal(errors.Wrap(err, "executing").Error())
|
||||
}
|
||||
|
|
@ -98,9 +96,9 @@ func TestApplyBookDomain(t *testing.T) {
|
|||
BookDomain: database.BookDomainExluding,
|
||||
Books: []database.Book{b1},
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&rule), "preparing rule")
|
||||
testutils.MustExec(t, testutils.DB.Save(&rule), "preparing rule")
|
||||
|
||||
conn, err := applyBookDomain(db.Debug(), rule)
|
||||
conn, err := applyBookDomain(testutils.DB, testutils.DB, rule)
|
||||
if err != nil {
|
||||
t.Fatal(errors.Wrap(err, "executing").Error())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,15 +25,17 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/dbconn"
|
||||
"github.com/dnote/dnote/pkg/server/job/repetition"
|
||||
"github.com/dnote/dnote/pkg/server/mailer"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/joho/godotenv"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func digestHandler(w http.ResponseWriter, r *http.Request) {
|
||||
db := database.DBConn
|
||||
func (c Context) digestHandler(w http.ResponseWriter, r *http.Request) {
|
||||
db := c.DB
|
||||
|
||||
q := r.URL.Query()
|
||||
digestUUID := q.Get("digest_uuid")
|
||||
|
|
@ -61,7 +63,13 @@ func digestHandler(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
now := time.Now()
|
||||
email, err := repetition.BuildEmail(now, user, "sung@getdnote.com", digest, rule)
|
||||
email, err := repetition.BuildEmail(db, repetition.BuildEmailParams{
|
||||
Now: now,
|
||||
User: user,
|
||||
EmailAddr: "sung@getdnote.com",
|
||||
Digest: digest,
|
||||
Rule: rule,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
|
|
@ -71,7 +79,7 @@ func digestHandler(w http.ResponseWriter, r *http.Request) {
|
|||
w.Write([]byte(body))
|
||||
}
|
||||
|
||||
func emailVerificationHandler(w http.ResponseWriter, r *http.Request) {
|
||||
func (c Context) emailVerificationHandler(w http.ResponseWriter, r *http.Request) {
|
||||
data := struct {
|
||||
Subject string
|
||||
Token string
|
||||
|
|
@ -90,7 +98,7 @@ func emailVerificationHandler(w http.ResponseWriter, r *http.Request) {
|
|||
w.Write([]byte(body))
|
||||
}
|
||||
|
||||
func homeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
func (c Context) homeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("Email development server is running."))
|
||||
}
|
||||
|
||||
|
|
@ -101,23 +109,29 @@ func init() {
|
|||
}
|
||||
}
|
||||
|
||||
// Context is a context holding global information
|
||||
type Context struct {
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
func main() {
|
||||
c := database.Config{
|
||||
db := dbconn.Open(dbconn.Config{
|
||||
Host: os.Getenv("DBHost"),
|
||||
Port: os.Getenv("DBPort"),
|
||||
Name: os.Getenv("DBName"),
|
||||
User: os.Getenv("DBUser"),
|
||||
Password: os.Getenv("DBPassword"),
|
||||
}
|
||||
database.Open(c)
|
||||
defer database.Close()
|
||||
})
|
||||
defer db.Close()
|
||||
|
||||
mailer.InitTemplates(nil)
|
||||
|
||||
log.Println("Email template development server running on http://127.0.0.1:2300")
|
||||
|
||||
http.HandleFunc("/", homeHandler)
|
||||
http.HandleFunc("/digest", digestHandler)
|
||||
http.HandleFunc("/email-verification", emailVerificationHandler)
|
||||
ctx := Context{DB: db}
|
||||
|
||||
http.HandleFunc("/", ctx.homeHandler)
|
||||
http.HandleFunc("/digest", ctx.digestHandler)
|
||||
http.HandleFunc("/email-verification", ctx.emailVerificationHandler)
|
||||
log.Fatal(http.ListenAndServe(":2300", nil))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import (
|
|||
"encoding/base64"
|
||||
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
|
@ -39,9 +40,7 @@ func generateRandomToken(bits int) (string, error) {
|
|||
|
||||
// GetToken returns an token of the given kind for the user
|
||||
// by first looking up any unused record and creating one if none exists.
|
||||
func GetToken(user database.User, kind string) (database.Token, error) {
|
||||
db := database.DBConn
|
||||
|
||||
func GetToken(db *gorm.DB, user database.User, kind string) (database.Token, error) {
|
||||
var tok database.Token
|
||||
conn := db.
|
||||
Where("user_id = ? AND type =? AND used_at IS NULL", user.ID, kind).
|
||||
|
|
|
|||
|
|
@ -27,10 +27,12 @@ import (
|
|||
|
||||
"github.com/dnote/dnote/pkg/clock"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/dbconn"
|
||||
"github.com/dnote/dnote/pkg/server/handlers"
|
||||
"github.com/dnote/dnote/pkg/server/job"
|
||||
"github.com/dnote/dnote/pkg/server/mailer"
|
||||
"github.com/dnote/dnote/pkg/server/web"
|
||||
"github.com/jinzhu/gorm"
|
||||
|
||||
"github.com/gobuffalo/packr/v2"
|
||||
"github.com/pkg/errors"
|
||||
|
|
@ -64,49 +66,70 @@ func initContext() web.Context {
|
|||
}
|
||||
}
|
||||
|
||||
func initServer() (*http.ServeMux, error) {
|
||||
apiRouter, err := handlers.NewRouter(&handlers.App{
|
||||
Clock: clock.New(),
|
||||
StripeAPIBackend: nil,
|
||||
WebURL: os.Getenv("WebURL"),
|
||||
})
|
||||
func initServer(app handlers.App) (*http.ServeMux, error) {
|
||||
apiRouter, err := handlers.NewRouter(&app)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "initializing router")
|
||||
}
|
||||
|
||||
ctx := initContext()
|
||||
webCtx := initContext()
|
||||
webHandlers := web.Init(webCtx)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/api/", http.StripPrefix("/api", apiRouter))
|
||||
mux.Handle("/static/", web.GetStaticHandler(ctx.StaticFileSystem))
|
||||
mux.HandleFunc("/service-worker.js", web.GetSWHandler(ctx.ServiceWorkerJs))
|
||||
mux.HandleFunc("/robots.txt", web.GetRobotsHandler(ctx.RobotsTxt))
|
||||
mux.HandleFunc("/", web.GetRootHandler(ctx.IndexHTML))
|
||||
mux.Handle("/static/", webHandlers.GetStatic)
|
||||
mux.HandleFunc("/service-worker.js", webHandlers.GetServiceWorker)
|
||||
mux.HandleFunc("/robots.txt", webHandlers.GetRobots)
|
||||
mux.HandleFunc("/", webHandlers.GetRoot)
|
||||
|
||||
return mux, nil
|
||||
}
|
||||
|
||||
func startCmd() {
|
||||
mailer.InitTemplates(nil)
|
||||
func initDB() *gorm.DB {
|
||||
var skipSSL bool
|
||||
if os.Getenv("GO_ENV") != "PRODUCTION" || os.Getenv("DB_NOSSL") != "" {
|
||||
skipSSL = true
|
||||
} else {
|
||||
skipSSL = false
|
||||
}
|
||||
|
||||
database.Open(database.Config{
|
||||
db := dbconn.Open(dbconn.Config{
|
||||
SkipSSL: skipSSL,
|
||||
Host: os.Getenv("DBHost"),
|
||||
Port: os.Getenv("DBPort"),
|
||||
Name: os.Getenv("DBName"),
|
||||
User: os.Getenv("DBUser"),
|
||||
Password: os.Getenv("DBPassword"),
|
||||
})
|
||||
database.InitSchema()
|
||||
defer database.Close()
|
||||
database.InitSchema(db)
|
||||
|
||||
if err := database.Migrate(); err != nil {
|
||||
return db
|
||||
}
|
||||
|
||||
func initApp(db *gorm.DB) handlers.App {
|
||||
return handlers.App{
|
||||
DB: db,
|
||||
Clock: clock.New(),
|
||||
StripeAPIBackend: nil,
|
||||
WebURL: os.Getenv("WebURL"),
|
||||
}
|
||||
}
|
||||
|
||||
func startCmd() {
|
||||
db := initDB()
|
||||
defer db.Close()
|
||||
|
||||
app := initApp(db)
|
||||
mailer.InitTemplates(nil)
|
||||
|
||||
if err := database.Migrate(app.DB); err != nil {
|
||||
panic(errors.Wrap(err, "running migrations"))
|
||||
}
|
||||
if err := job.Run(); err != nil {
|
||||
if err := job.Run(db); err != nil {
|
||||
panic(errors.Wrap(err, "running job"))
|
||||
}
|
||||
|
||||
srv, err := initServer()
|
||||
srv, err := initServer(app)
|
||||
if err != nil {
|
||||
panic(errors.Wrap(err, "initializing server"))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,15 +20,14 @@ package operations
|
|||
|
||||
import (
|
||||
"github.com/dnote/dnote/pkg/clock"
|
||||
"github.com/dnote/dnote/pkg/server/helpers"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/helpers"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// CreateBook creates a book with the next usn and updates the user's max_usn
|
||||
func CreateBook(user database.User, clock clock.Clock, name string) (database.Book, error) {
|
||||
db := database.DBConn
|
||||
func CreateBook(db *gorm.DB, user database.User, clock clock.Clock, name string) (database.Book, error) {
|
||||
tx := db.Begin()
|
||||
|
||||
nextUSN, err := incrementUserUSN(tx, user.ID)
|
||||
|
|
|
|||
|
|
@ -29,10 +29,6 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func init() {
|
||||
testutils.InitTestDB()
|
||||
}
|
||||
|
||||
func TestCreateBook(t *testing.T) {
|
||||
testCases := []struct {
|
||||
userUSN int
|
||||
|
|
@ -59,17 +55,16 @@ func TestCreateBook(t *testing.T) {
|
|||
for idx, tc := range testCases {
|
||||
func() {
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
anotherUser := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
c := clock.NewMock()
|
||||
|
||||
book, err := CreateBook(user, c, tc.label)
|
||||
book, err := CreateBook(testutils.DB, user, c, tc.label)
|
||||
if err != nil {
|
||||
t.Fatal(errors.Wrap(err, "creating book"))
|
||||
}
|
||||
|
|
@ -78,13 +73,13 @@ func TestCreateBook(t *testing.T) {
|
|||
var bookRecord database.Book
|
||||
var userRecord database.User
|
||||
|
||||
if err := db.Model(&database.Book{}).Count(&bookCount).Error; err != nil {
|
||||
if err := testutils.DB.Model(&database.Book{}).Count(&bookCount).Error; err != nil {
|
||||
t.Fatal(errors.Wrap(err, "counting books"))
|
||||
}
|
||||
if err := db.First(&bookRecord).Error; err != nil {
|
||||
if err := testutils.DB.First(&bookRecord).Error; err != nil {
|
||||
t.Fatal(errors.Wrap(err, "finding book"))
|
||||
}
|
||||
if err := db.Where("id = ?", user.ID).First(&userRecord).Error; err != nil {
|
||||
if err := testutils.DB.Where("id = ?", user.ID).First(&userRecord).Error; err != nil {
|
||||
t.Fatal(errors.Wrap(err, "finding user"))
|
||||
}
|
||||
|
||||
|
|
@ -124,18 +119,17 @@ func TestDeleteBook(t *testing.T) {
|
|||
for idx, tc := range testCases {
|
||||
func() {
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
anotherUser := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
book := database.Book{UserID: user.ID, Label: "js", Deleted: false}
|
||||
testutils.MustExec(t, db.Save(&book), fmt.Sprintf("preparing book for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Save(&book), fmt.Sprintf("preparing book for test case %d", idx))
|
||||
|
||||
tx := db.Begin()
|
||||
tx := testutils.DB.Begin()
|
||||
ret, err := DeleteBook(tx, user, book)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
|
|
@ -147,9 +141,9 @@ func TestDeleteBook(t *testing.T) {
|
|||
var bookRecord database.Book
|
||||
var userRecord database.User
|
||||
|
||||
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
|
||||
testutils.MustExec(t, db.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
|
||||
|
||||
assert.Equal(t, bookCount, 1, "book count mismatch")
|
||||
assert.Equal(t, bookRecord.UserID, user.ID, "book user_id mismatch")
|
||||
|
|
@ -202,20 +196,19 @@ func TestUpdateBook(t *testing.T) {
|
|||
for idx, tc := range testCases {
|
||||
func() {
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
anotherUser := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
c := clock.NewMock()
|
||||
|
||||
b := database.Book{UserID: user.ID, Deleted: false, Label: tc.expectedLabel}
|
||||
testutils.MustExec(t, db.Save(&b), fmt.Sprintf("preparing book for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Save(&b), fmt.Sprintf("preparing book for test case %d", idx))
|
||||
|
||||
tx := db.Begin()
|
||||
tx := testutils.DB.Begin()
|
||||
|
||||
book, err := UpdateBook(tx, c, user, b, tc.payloadLabel)
|
||||
if err != nil {
|
||||
|
|
@ -228,9 +221,9 @@ func TestUpdateBook(t *testing.T) {
|
|||
var bookCount int
|
||||
var bookRecord database.Book
|
||||
var userRecord database.User
|
||||
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
|
||||
testutils.MustExec(t, db.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
|
||||
|
||||
assert.Equal(t, bookCount, 1, "book count mismatch")
|
||||
|
||||
|
|
|
|||
|
|
@ -28,10 +28,6 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func init() {
|
||||
testutils.InitTestDB()
|
||||
}
|
||||
|
||||
func TestIncremenetUserUSN(t *testing.T) {
|
||||
testCases := []struct {
|
||||
maxUSN int
|
||||
|
|
@ -51,13 +47,12 @@ func TestIncremenetUserUSN(t *testing.T) {
|
|||
for idx, tc := range testCases {
|
||||
func() {
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.maxUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.maxUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
// execute
|
||||
tx := db.Begin()
|
||||
tx := testutils.DB.Begin()
|
||||
nextUSN, err := incrementUserUSN(tx, user.ID)
|
||||
if err != nil {
|
||||
t.Fatal(errors.Wrap(err, "incrementing the user usn"))
|
||||
|
|
@ -66,7 +61,7 @@ func TestIncremenetUserUSN(t *testing.T) {
|
|||
|
||||
// test
|
||||
var userRecord database.User
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
|
||||
|
||||
assert.Equal(t, userRecord.MaxUSN, tc.expectedMaxUSN, fmt.Sprintf("user max_usn mismatch for case %d", idx))
|
||||
assert.Equal(t, nextUSN, tc.expectedMaxUSN, fmt.Sprintf("next_usn mismatch for case %d", idx))
|
||||
|
|
|
|||
17
pkg/server/operations/main_test.go
Normal file
17
pkg/server/operations/main_test.go
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
package operations
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/dnote/dnote/pkg/server/testutils"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutils.InitTestDB()
|
||||
|
||||
code := m.Run()
|
||||
testutils.ClearData()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
|
@ -29,8 +29,7 @@ import (
|
|||
|
||||
// CreateNote creates a note with the next usn and updates the user's max_usn.
|
||||
// It returns the created note.
|
||||
func CreateNote(user database.User, clock clock.Clock, bookUUID, content string, addedOn *int64, editedOn *int64, public bool) (database.Note, error) {
|
||||
db := database.DBConn
|
||||
func CreateNote(db *gorm.DB, user database.User, clock clock.Clock, bookUUID, content string, addedOn *int64, editedOn *int64, public bool) (database.Note, error) {
|
||||
tx := db.Begin()
|
||||
|
||||
nextUSN, err := incrementUserUSN(tx, user.ID)
|
||||
|
|
@ -163,14 +162,12 @@ func DeleteNote(tx *gorm.DB, user database.User, note database.Note) (database.N
|
|||
}
|
||||
|
||||
// GetNote retrieves a note for the given user
|
||||
func GetNote(uuid string, user database.User) (database.Note, bool, error) {
|
||||
func GetNote(db *gorm.DB, uuid string, user database.User) (database.Note, bool, error) {
|
||||
zeroNote := database.Note{}
|
||||
if !helpers.ValidateUUID(uuid) {
|
||||
return zeroNote, false, nil
|
||||
}
|
||||
|
||||
db := database.DBConn
|
||||
|
||||
conn := db.Where("notes.uuid = ? AND deleted = ?", uuid, false)
|
||||
conn = database.PreloadNote(conn)
|
||||
|
||||
|
|
|
|||
|
|
@ -30,10 +30,6 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func init() {
|
||||
testutils.InitTestDB()
|
||||
}
|
||||
|
||||
func TestCreateNote(t *testing.T) {
|
||||
serverTime := time.Date(2017, time.March, 14, 21, 15, 0, 0, time.UTC)
|
||||
mockClock := clock.NewMock()
|
||||
|
|
@ -79,19 +75,18 @@ func TestCreateNote(t *testing.T) {
|
|||
for idx, tc := range testCases {
|
||||
func() {
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
anotherUser := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
b1 := database.Book{UserID: user.ID, Label: "js", Deleted: false}
|
||||
testutils.MustExec(t, db.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
|
||||
|
||||
tx := db.Begin()
|
||||
if _, err := CreateNote(user, mockClock, b1.UUID, "note content", tc.addedOn, tc.editedOn, false); err != nil {
|
||||
tx := testutils.DB.Begin()
|
||||
if _, err := CreateNote(testutils.DB, user, mockClock, b1.UUID, "note content", tc.addedOn, tc.editedOn, false); err != nil {
|
||||
tx.Rollback()
|
||||
t.Fatal(errors.Wrap(err, "deleting note"))
|
||||
}
|
||||
|
|
@ -101,10 +96,10 @@ func TestCreateNote(t *testing.T) {
|
|||
var noteRecord database.Note
|
||||
var userRecord database.User
|
||||
|
||||
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting book for test case %d", idx))
|
||||
testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), fmt.Sprintf("counting notes for test case %d", idx))
|
||||
testutils.MustExec(t, db.First(¬eRecord), fmt.Sprintf("finding note for test case %d", idx))
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting book for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), fmt.Sprintf("counting notes for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.First(¬eRecord), fmt.Sprintf("finding note for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
|
||||
|
||||
assert.Equal(t, bookCount, 1, "book count mismatch")
|
||||
assert.Equal(t, noteCount, 1, "note count mismatch")
|
||||
|
|
@ -139,25 +134,24 @@ func TestUpdateNote(t *testing.T) {
|
|||
for idx, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), "preparing user max_usn for test case")
|
||||
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), "preparing user max_usn for test case")
|
||||
|
||||
anotherUser := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), "preparing user max_usn for test case")
|
||||
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), "preparing user max_usn for test case")
|
||||
|
||||
b1 := database.Book{UserID: user.ID, Label: "js", Deleted: false}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing b1 for test case")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1 for test case")
|
||||
|
||||
note := database.Note{UserID: user.ID, Deleted: false, Body: "test content", BookUUID: b1.UUID}
|
||||
testutils.MustExec(t, db.Save(¬e), "preparing note for test case")
|
||||
testutils.MustExec(t, testutils.DB.Save(¬e), "preparing note for test case")
|
||||
|
||||
c := clock.NewMock()
|
||||
content := "updated test content"
|
||||
public := true
|
||||
|
||||
tx := db.Begin()
|
||||
tx := testutils.DB.Begin()
|
||||
if _, err := UpdateNote(tx, user, c, note, &UpdateNoteParams{
|
||||
Content: &content,
|
||||
Public: &public,
|
||||
|
|
@ -171,10 +165,10 @@ func TestUpdateNote(t *testing.T) {
|
|||
var noteRecord database.Note
|
||||
var userRecord database.User
|
||||
|
||||
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting book for test case")
|
||||
testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes for test case")
|
||||
testutils.MustExec(t, db.First(¬eRecord), "finding note for test case")
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user for test case")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting book for test case")
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes for test case")
|
||||
testutils.MustExec(t, testutils.DB.First(¬eRecord), "finding note for test case")
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user for test case")
|
||||
|
||||
expectedUSN := tc.userUSN + 1
|
||||
assert.Equal(t, bookCount, 1, "book count mismatch")
|
||||
|
|
@ -211,21 +205,20 @@ func TestDeleteNote(t *testing.T) {
|
|||
for idx, tc := range testCases {
|
||||
func() {
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
anotherUser := testutils.SetupUserData()
|
||||
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
|
||||
|
||||
b1 := database.Book{UserID: user.ID, Label: "testBook"}
|
||||
testutils.MustExec(t, db.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
|
||||
|
||||
note := database.Note{UserID: user.ID, Deleted: false, Body: "test content", BookUUID: b1.UUID}
|
||||
testutils.MustExec(t, db.Save(¬e), fmt.Sprintf("preparing note for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Save(¬e), fmt.Sprintf("preparing note for test case %d", idx))
|
||||
|
||||
tx := db.Begin()
|
||||
tx := testutils.DB.Begin()
|
||||
ret, err := DeleteNote(tx, user, note)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
|
|
@ -237,9 +230,9 @@ func TestDeleteNote(t *testing.T) {
|
|||
var noteRecord database.Note
|
||||
var userRecord database.User
|
||||
|
||||
testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), fmt.Sprintf("counting notes for test case %d", idx))
|
||||
testutils.MustExec(t, db.First(¬eRecord), fmt.Sprintf("finding note for test case %d", idx))
|
||||
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), fmt.Sprintf("counting notes for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.First(¬eRecord), fmt.Sprintf("finding note for test case %d", idx))
|
||||
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
|
||||
|
||||
assert.Equal(t, noteCount, 1, "note count mismatch")
|
||||
|
||||
|
|
@ -261,14 +254,13 @@ func TestGetNote(t *testing.T) {
|
|||
user := testutils.SetupUserData()
|
||||
anotherUser := testutils.SetupUserData()
|
||||
|
||||
db := database.DBConn
|
||||
defer testutils.ClearData()
|
||||
|
||||
b1 := database.Book{
|
||||
UserID: user.ID,
|
||||
Label: "js",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing b1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
|
||||
|
||||
privateNote := database.Note{
|
||||
UserID: user.ID,
|
||||
|
|
@ -277,7 +269,7 @@ func TestGetNote(t *testing.T) {
|
|||
Deleted: false,
|
||||
Public: false,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&privateNote), "preparing privateNote")
|
||||
testutils.MustExec(t, testutils.DB.Save(&privateNote), "preparing privateNote")
|
||||
|
||||
publicNote := database.Note{
|
||||
UserID: user.ID,
|
||||
|
|
@ -286,11 +278,11 @@ func TestGetNote(t *testing.T) {
|
|||
Deleted: false,
|
||||
Public: true,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&publicNote), "preparing privateNote")
|
||||
testutils.MustExec(t, testutils.DB.Save(&publicNote), "preparing privateNote")
|
||||
|
||||
var privateNoteRecord, publicNoteRecord database.Note
|
||||
testutils.MustExec(t, db.Where("uuid = ?", privateNote.UUID).Preload("Book").Preload("User").First(&privateNoteRecord), "finding privateNote")
|
||||
testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).Preload("Book").Preload("User").First(&publicNoteRecord), "finding publicNote")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", privateNote.UUID).Preload("Book").Preload("User").First(&privateNoteRecord), "finding privateNote")
|
||||
testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).Preload("Book").Preload("User").First(&publicNoteRecord), "finding publicNote")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
|
@ -338,7 +330,7 @@ func TestGetNote(t *testing.T) {
|
|||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
note, ok, err := GetNote(tc.note.UUID, tc.user)
|
||||
note, ok, err := GetNote(testutils.DB, tc.note.UUID, tc.user)
|
||||
if err != nil {
|
||||
t.Fatal(errors.Wrap(err, "executing"))
|
||||
}
|
||||
|
|
@ -352,14 +344,13 @@ func TestGetNote(t *testing.T) {
|
|||
func TestGetNote_nonexistent(t *testing.T) {
|
||||
user := testutils.SetupUserData()
|
||||
|
||||
db := database.DBConn
|
||||
defer testutils.ClearData()
|
||||
|
||||
b1 := database.Book{
|
||||
UserID: user.ID,
|
||||
Label: "js",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing b1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
|
||||
|
||||
n1UUID := "4fd19336-671e-4ff3-8f22-662b80e22edc"
|
||||
n1 := database.Note{
|
||||
|
|
@ -370,10 +361,10 @@ func TestGetNote_nonexistent(t *testing.T) {
|
|||
Deleted: false,
|
||||
Public: false,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n1), "preparing n1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n1), "preparing n1")
|
||||
|
||||
nonexistentUUID := "4fd19336-671e-4ff3-8f22-662b80e22edd"
|
||||
note, ok, err := GetNote(nonexistentUUID, user)
|
||||
note, ok, err := GetNote(testutils.DB, nonexistentUUID, user)
|
||||
if err != nil {
|
||||
t.Fatal(errors.Wrap(err, "executing"))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stripe/stripe-go"
|
||||
"github.com/stripe/stripe-go/sub"
|
||||
)
|
||||
|
|
@ -66,9 +67,7 @@ func ReactivateSub(subscriptionID string, user database.User) error {
|
|||
}
|
||||
|
||||
// MarkUnsubscribed marks the user unsubscribed
|
||||
func MarkUnsubscribed(stripeCustomerID string) error {
|
||||
db := database.DBConn
|
||||
|
||||
func MarkUnsubscribed(db *gorm.DB, stripeCustomerID string) error {
|
||||
var user database.User
|
||||
if err := db.Where("stripe_customer_id = ?", stripeCustomerID).First(&user).Error; err != nil {
|
||||
return errors.Wrap(err, "finding user")
|
||||
|
|
|
|||
|
|
@ -104,8 +104,7 @@ func createDefaultRepetitionRule(user database.User, tx *gorm.DB) error {
|
|||
}
|
||||
|
||||
// CreateUser creates a user
|
||||
func CreateUser(email, password string) (database.User, error) {
|
||||
db := database.DBConn
|
||||
func CreateUser(db *gorm.DB, email, password string) (database.User, error) {
|
||||
tx := db.Begin()
|
||||
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
package permissions
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/dnote/dnote/pkg/assert"
|
||||
|
|
@ -26,22 +27,26 @@ import (
|
|||
"github.com/dnote/dnote/pkg/server/testutils"
|
||||
)
|
||||
|
||||
func init() {
|
||||
func TestMain(m *testing.M) {
|
||||
testutils.InitTestDB()
|
||||
|
||||
code := m.Run()
|
||||
testutils.ClearData()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestViewNote(t *testing.T) {
|
||||
user := testutils.SetupUserData()
|
||||
anotherUser := testutils.SetupUserData()
|
||||
|
||||
db := database.DBConn
|
||||
defer testutils.ClearData()
|
||||
|
||||
b1 := database.Book{
|
||||
UserID: user.ID,
|
||||
Label: "js",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing b1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
|
||||
|
||||
privateNote := database.Note{
|
||||
UserID: user.ID,
|
||||
|
|
@ -50,7 +55,7 @@ func TestViewNote(t *testing.T) {
|
|||
Deleted: false,
|
||||
Public: false,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&privateNote), "preparing privateNote")
|
||||
testutils.MustExec(t, testutils.DB.Save(&privateNote), "preparing privateNote")
|
||||
|
||||
publicNote := database.Note{
|
||||
UserID: user.ID,
|
||||
|
|
@ -59,7 +64,7 @@ func TestViewNote(t *testing.T) {
|
|||
Deleted: false,
|
||||
Public: true,
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&publicNote), "preparing privateNote")
|
||||
testutils.MustExec(t, testutils.DB.Save(&publicNote), "preparing privateNote")
|
||||
|
||||
t.Run("owner accessing private note", func(t *testing.T) {
|
||||
result := ViewNote(&user, privateNote)
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/dbconn"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stripe/stripe-go"
|
||||
|
|
@ -42,31 +43,67 @@ func init() {
|
|||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// DB is the database connection to a test database
|
||||
var DB *gorm.DB
|
||||
|
||||
// InitTestDB establishes connection pool with the test database specified by
|
||||
// the environment variable configuration and initalizes a new schema
|
||||
func InitTestDB() {
|
||||
c := database.Config{
|
||||
db := dbconn.Open(dbconn.Config{
|
||||
Host: os.Getenv("DBHost"),
|
||||
Port: os.Getenv("DBPort"),
|
||||
Name: os.Getenv("DBName"),
|
||||
User: os.Getenv("DBUser"),
|
||||
Password: os.Getenv("DBPassword"),
|
||||
})
|
||||
database.InitSchema(db)
|
||||
|
||||
DB = db
|
||||
}
|
||||
|
||||
// ClearData deletes all records from the database
|
||||
func ClearData() {
|
||||
if err := DB.Delete(&database.Book{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear books"))
|
||||
}
|
||||
if err := DB.Delete(&database.Note{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear notes"))
|
||||
}
|
||||
if err := DB.Delete(&database.Notification{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear notifications"))
|
||||
}
|
||||
if err := DB.Delete(&database.User{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear users"))
|
||||
}
|
||||
if err := DB.Delete(&database.Account{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear accounts"))
|
||||
}
|
||||
if err := DB.Delete(&database.Token{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear reset_tokens"))
|
||||
}
|
||||
if err := DB.Delete(&database.EmailPreference{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear reset_tokens"))
|
||||
}
|
||||
if err := DB.Delete(&database.Session{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear sessions"))
|
||||
}
|
||||
if err := DB.Delete(&database.Digest{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear digests"))
|
||||
}
|
||||
if err := DB.Delete(&database.RepetitionRule{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear digests"))
|
||||
}
|
||||
database.Open(c)
|
||||
database.InitSchema()
|
||||
}
|
||||
|
||||
// SetupUserData creates and returns a new user for testing purposes
|
||||
func SetupUserData() database.User {
|
||||
db := database.DBConn
|
||||
|
||||
user := database.User{
|
||||
APIKey: "test-api-key",
|
||||
Name: "user-name",
|
||||
Cloud: true,
|
||||
}
|
||||
|
||||
if err := db.Save(&user).Error; err != nil {
|
||||
if err := DB.Save(&user).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to prepare user"))
|
||||
}
|
||||
|
||||
|
|
@ -75,8 +112,6 @@ func SetupUserData() database.User {
|
|||
|
||||
// SetupAccountData creates and returns a new account for the user
|
||||
func SetupAccountData(user database.User, email, password string) database.Account {
|
||||
db := database.DBConn
|
||||
|
||||
account := database.Account{
|
||||
UserID: user.ID,
|
||||
}
|
||||
|
|
@ -90,7 +125,7 @@ func SetupAccountData(user database.User, email, password string) database.Accou
|
|||
}
|
||||
account.Password = database.ToNullString(string(hashedPassword))
|
||||
|
||||
if err := db.Save(&account).Error; err != nil {
|
||||
if err := DB.Save(&account).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to prepare account"))
|
||||
}
|
||||
|
||||
|
|
@ -99,8 +134,6 @@ func SetupAccountData(user database.User, email, password string) database.Accou
|
|||
|
||||
// SetupClassicAccountData creates and returns a new account for the user
|
||||
func SetupClassicAccountData(user database.User, email string) database.Account {
|
||||
db := database.DBConn
|
||||
|
||||
// email: alice@example.com
|
||||
// password: pass1234
|
||||
// masterKey: WbUvagj9O6o1Z+4+7COjo7Uqm4MD2QE9EWFXne8+U+8=
|
||||
|
|
@ -117,7 +150,7 @@ func SetupClassicAccountData(user database.User, email string) database.Account
|
|||
account.Email = database.ToNullString(email)
|
||||
}
|
||||
|
||||
if err := db.Save(&account).Error; err != nil {
|
||||
if err := DB.Save(&account).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to prepare account"))
|
||||
}
|
||||
|
||||
|
|
@ -126,14 +159,12 @@ func SetupClassicAccountData(user database.User, email string) database.Account
|
|||
|
||||
// SetupSession creates and returns a new user session
|
||||
func SetupSession(t *testing.T, user database.User) database.Session {
|
||||
db := database.DBConn
|
||||
|
||||
session := database.Session{
|
||||
Key: "Vvgm3eBXfXGEFWERI7faiRJ3DAzJw+7DdT9J1LEyNfI=",
|
||||
UserID: user.ID,
|
||||
ExpiresAt: time.Now().Add(time.Hour * 24),
|
||||
}
|
||||
if err := db.Save(&session).Error; err != nil {
|
||||
if err := DB.Save(&session).Error; err != nil {
|
||||
t.Fatal(errors.Wrap(err, "Failed to prepare user"))
|
||||
}
|
||||
|
||||
|
|
@ -142,56 +173,18 @@ func SetupSession(t *testing.T, user database.User) database.Session {
|
|||
|
||||
// SetupEmailPreferenceData creates and returns a new email frequency for a user
|
||||
func SetupEmailPreferenceData(user database.User, digestWeekly bool) database.EmailPreference {
|
||||
db := database.DBConn
|
||||
|
||||
frequency := database.EmailPreference{
|
||||
UserID: user.ID,
|
||||
DigestWeekly: digestWeekly,
|
||||
}
|
||||
|
||||
if err := db.Save(&frequency).Error; err != nil {
|
||||
if err := DB.Save(&frequency).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to prepare email frequency"))
|
||||
}
|
||||
|
||||
return frequency
|
||||
}
|
||||
|
||||
// ClearData deletes all records from the database
|
||||
func ClearData() {
|
||||
db := database.DBConn
|
||||
|
||||
if err := db.Delete(&database.Book{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear books"))
|
||||
}
|
||||
if err := db.Delete(&database.Note{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear notes"))
|
||||
}
|
||||
if err := db.Delete(&database.Notification{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear notifications"))
|
||||
}
|
||||
if err := db.Delete(&database.User{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear users"))
|
||||
}
|
||||
if err := db.Delete(&database.Account{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear accounts"))
|
||||
}
|
||||
if err := db.Delete(&database.Token{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear reset_tokens"))
|
||||
}
|
||||
if err := db.Delete(&database.EmailPreference{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear reset_tokens"))
|
||||
}
|
||||
if err := db.Delete(&database.Session{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear sessions"))
|
||||
}
|
||||
if err := db.Delete(&database.Digest{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear digests"))
|
||||
}
|
||||
if err := db.Delete(&database.RepetitionRule{}).Error; err != nil {
|
||||
panic(errors.Wrap(err, "Failed to clear digests"))
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPDo makes an HTTP request and returns a response
|
||||
func HTTPDo(t *testing.T, req *http.Request) *http.Response {
|
||||
hc := http.Client{
|
||||
|
|
@ -213,8 +206,6 @@ func HTTPDo(t *testing.T, req *http.Request) *http.Response {
|
|||
|
||||
// HTTPAuthDo makes an HTTP request with an appropriate authorization header for a user
|
||||
func HTTPAuthDo(t *testing.T, req *http.Request, user database.User) *http.Response {
|
||||
db := database.DBConn
|
||||
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
t.Fatal(errors.Wrap(err, "reading random bits"))
|
||||
|
|
@ -225,7 +216,7 @@ func HTTPAuthDo(t *testing.T, req *http.Request, user database.User) *http.Respo
|
|||
UserID: user.ID,
|
||||
ExpiresAt: time.Now().Add(time.Hour * 10 * 24),
|
||||
}
|
||||
if err := db.Save(&session).Error; err != nil {
|
||||
if err := DB.Save(&session).Error; err != nil {
|
||||
t.Fatal(errors.Wrap(err, "Failed to prepare user"))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import (
|
|||
"net/http"
|
||||
"regexp"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
|
@ -58,8 +59,8 @@ func NewAppShell(content []byte) (AppShell, error) {
|
|||
}
|
||||
|
||||
// Execute executes the index template
|
||||
func (a AppShell) Execute(r *http.Request) ([]byte, error) {
|
||||
data, err := a.getData(r)
|
||||
func (a AppShell) Execute(r *http.Request, db *gorm.DB) ([]byte, error) {
|
||||
data, err := a.getData(db, r)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getting data")
|
||||
}
|
||||
|
|
@ -72,11 +73,11 @@ func (a AppShell) Execute(r *http.Request) ([]byte, error) {
|
|||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (a AppShell) getData(r *http.Request) (tmplData, error) {
|
||||
func (a AppShell) getData(db *gorm.DB, r *http.Request) (tmplData, error) {
|
||||
path := r.URL.Path
|
||||
|
||||
if ok, params := matchPath(path, notesPathRegex); ok {
|
||||
p, err := a.newNotePage(r, params[0])
|
||||
p, err := a.newNotePage(db, r, params[0])
|
||||
if err != nil {
|
||||
return tmplData{}, errors.Wrap(err, "instantiating note page")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,10 +29,6 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func init() {
|
||||
testutils.InitTestDB()
|
||||
}
|
||||
|
||||
func TestAppShellExecute(t *testing.T) {
|
||||
t.Run("home", func(t *testing.T) {
|
||||
a, err := NewAppShell([]byte("<head><title>{{ .Title }}</title>{{ .MetaTags }}</head>"))
|
||||
|
|
@ -45,7 +41,7 @@ func TestAppShellExecute(t *testing.T) {
|
|||
t.Fatal(errors.Wrap(err, "preparing request"))
|
||||
}
|
||||
|
||||
b, err := a.Execute(r)
|
||||
b, err := a.Execute(r, testutils.DB)
|
||||
if err != nil {
|
||||
t.Fatal(errors.Wrap(err, "executing"))
|
||||
}
|
||||
|
|
@ -55,21 +51,20 @@ func TestAppShellExecute(t *testing.T) {
|
|||
|
||||
t.Run("note", func(t *testing.T) {
|
||||
defer testutils.ClearData()
|
||||
db := database.DBConn
|
||||
|
||||
user := testutils.SetupUserData()
|
||||
b1 := database.Book{
|
||||
UserID: user.ID,
|
||||
Label: "js",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&b1), "preparing b1")
|
||||
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
|
||||
n1 := database.Note{
|
||||
UserID: user.ID,
|
||||
BookUUID: b1.UUID,
|
||||
Public: true,
|
||||
Body: "n1 content",
|
||||
}
|
||||
testutils.MustExec(t, db.Save(&n1), "preparing note")
|
||||
testutils.MustExec(t, testutils.DB.Save(&n1), "preparing note")
|
||||
|
||||
a, err := NewAppShell([]byte("{{ .MetaTags }}"))
|
||||
if err != nil {
|
||||
|
|
@ -82,7 +77,7 @@ func TestAppShellExecute(t *testing.T) {
|
|||
t.Fatal(errors.Wrap(err, "preparing request"))
|
||||
}
|
||||
|
||||
b, err := a.Execute(r)
|
||||
b, err := a.Execute(r, testutils.DB)
|
||||
if err != nil {
|
||||
t.Fatal(errors.Wrap(err, "executing"))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ import (
|
|||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/handlers"
|
||||
"github.com/dnote/dnote/pkg/server/operations"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
|
@ -51,13 +52,13 @@ type notePage struct {
|
|||
T *template.Template
|
||||
}
|
||||
|
||||
func (a AppShell) newNotePage(r *http.Request, noteUUID string) (notePage, error) {
|
||||
user, _, err := handlers.AuthWithSession(r, nil)
|
||||
func (a AppShell) newNotePage(db *gorm.DB, r *http.Request, noteUUID string) (notePage, error) {
|
||||
user, _, err := handlers.AuthWithSession(db, r, nil)
|
||||
if err != nil {
|
||||
return notePage{}, errors.Wrap(err, "authenticating with session")
|
||||
}
|
||||
|
||||
note, ok, err := operations.GetNote(noteUUID, user)
|
||||
note, ok, err := operations.GetNote(db, noteUUID, user)
|
||||
|
||||
if !ok {
|
||||
return notePage{}, ErrNotFound
|
||||
|
|
|
|||
17
pkg/server/tmpl/main_test.go
Normal file
17
pkg/server/tmpl/main_test.go
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
package tmpl
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/dnote/dnote/pkg/server/testutils"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutils.InitTestDB()
|
||||
|
||||
code := m.Run()
|
||||
testutils.ClearData()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
|
@ -24,20 +24,40 @@ import (
|
|||
|
||||
"github.com/dnote/dnote/pkg/server/handlers"
|
||||
"github.com/dnote/dnote/pkg/server/tmpl"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Context contains contents of web assets
|
||||
type Context struct {
|
||||
DB *gorm.DB
|
||||
IndexHTML []byte
|
||||
RobotsTxt []byte
|
||||
ServiceWorkerJs []byte
|
||||
StaticFileSystem http.FileSystem
|
||||
}
|
||||
|
||||
// GetRootHandler returns an HTTP handler that serves the app shell
|
||||
func GetRootHandler(b []byte) http.HandlerFunc {
|
||||
appShell, err := tmpl.NewAppShell(b)
|
||||
// Handlers are a group of web handlers
|
||||
type Handlers struct {
|
||||
GetRoot http.HandlerFunc
|
||||
GetRobots http.HandlerFunc
|
||||
GetServiceWorker http.HandlerFunc
|
||||
GetStatic http.Handler
|
||||
}
|
||||
|
||||
// Init initializes the handlers
|
||||
func Init(c Context) Handlers {
|
||||
return Handlers{
|
||||
GetRoot: getRootHandler(c),
|
||||
GetRobots: getRobotsHandler(c),
|
||||
GetServiceWorker: getSWHandler(c),
|
||||
GetStatic: getStaticHandler(c),
|
||||
}
|
||||
}
|
||||
|
||||
// getRootHandler returns an HTTP handler that serves the app shell
|
||||
func getRootHandler(c Context) http.HandlerFunc {
|
||||
appShell, err := tmpl.NewAppShell(c.IndexHTML)
|
||||
if err != nil {
|
||||
panic(errors.Wrap(err, "initializing app shell"))
|
||||
}
|
||||
|
|
@ -46,7 +66,7 @@ func GetRootHandler(b []byte) http.HandlerFunc {
|
|||
// index.html must not be cached
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
|
||||
buf, err := appShell.Execute(r)
|
||||
buf, err := appShell.Execute(r, c.DB)
|
||||
if err != nil {
|
||||
if errors.Cause(err) == tmpl.ErrNotFound {
|
||||
handlers.RespondNotFound(w)
|
||||
|
|
@ -60,24 +80,25 @@ func GetRootHandler(b []byte) http.HandlerFunc {
|
|||
}
|
||||
}
|
||||
|
||||
// GetRobotsHandler returns an HTTP handler that serves robots.txt
|
||||
func GetRobotsHandler(b []byte) http.HandlerFunc {
|
||||
// getRobotsHandler returns an HTTP handler that serves robots.txt
|
||||
func getRobotsHandler(c Context) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Write(b)
|
||||
w.Write(c.RobotsTxt)
|
||||
}
|
||||
}
|
||||
|
||||
// GetSWHandler returns an HTTP handler that serves service worker
|
||||
func GetSWHandler(b []byte) http.HandlerFunc {
|
||||
// getSWHandler returns an HTTP handler that serves service worker
|
||||
func getSWHandler(c Context) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Content-Type", "application/javascript")
|
||||
w.Write(b)
|
||||
w.Write(c.ServiceWorkerJs)
|
||||
}
|
||||
}
|
||||
|
||||
// GetStaticHandler returns an HTTP handler that serves static files from a filesystem
|
||||
func GetStaticHandler(root http.FileSystem) http.Handler {
|
||||
// getStaticHandler returns an HTTP handler that serves static files from a filesystem
|
||||
func getStaticHandler(c Context) http.Handler {
|
||||
root := c.StaticFileSystem
|
||||
return http.StripPrefix("/static/", http.FileServer(root))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,26 +19,27 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/dnote/dnote/pkg/server/helpers"
|
||||
"github.com/dnote/dnote/pkg/server/database"
|
||||
"github.com/dnote/dnote/pkg/server/dbconn"
|
||||
"github.com/dnote/dnote/pkg/server/helpers"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
c := database.Config{
|
||||
db, err := dbconn.Open(dbconn.Config{
|
||||
Host: os.Getenv("DBHost"),
|
||||
Port: os.Getenv("DBPort"),
|
||||
Name: os.Getenv("DBName"),
|
||||
User: os.Getenv("DBUser"),
|
||||
Password: os.Getenv("DBPassword"),
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
database.Open(c)
|
||||
|
||||
db := database.DBConn
|
||||
tx := db.Begin()
|
||||
|
||||
userID, err := helpers.GetDemoUserID()
|
||||
userID, err := helpers.GetDemoUserID(db)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ if [ "${WATCH-false}" == true ]; then
|
|||
while inotifywait --exclude .swp -e modify -r .; do go test ./... -cover -p 1; done;
|
||||
set -e
|
||||
else
|
||||
# go test ./... -cover -p 1
|
||||
go test ./... -cover -p 1
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
User-agent: *
|
||||
Disallow: /
|
||||
Allow: /
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue