diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 920af76f..ac44b4f3 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -11,25 +11,6 @@ jobs:
build:
runs-on: ubuntu-22.04
- services:
- postgres:
- image: postgres:14
- env:
- POSTGRES_USER: postgres
- POSTGRES_PASSWORD: postgres
- POSTGRES_DB: dnote_test
- POSTGRES_PORT: 5432
- # Wait until postgres has started
- options: >-
- --health-cmd pg_isready
- --health-interval 10s
- --health-timeout 5s
- --health-retries 5
-
- # Expose port to the host
- ports:
- - 5432:5432
-
steps:
- uses: actions/checkout@v5
- uses: actions/setup-go@v6
diff --git a/.github/workflows/release-server.yml b/.github/workflows/release-server.yml
new file mode 100644
index 00000000..75657e1e
--- /dev/null
+++ b/.github/workflows/release-server.yml
@@ -0,0 +1,91 @@
+name: Release Server
+
+on:
+ push:
+ tags:
+ - 'server-v*'
+
+jobs:
+ release:
+ runs-on: ubuntu-22.04
+ permissions:
+ contents: write
+
+ steps:
+ - uses: actions/checkout@v5
+ - uses: actions/setup-go@v6
+ with:
+ go-version: '>=1.25.0'
+ - uses: actions/setup-node@v4
+ with:
+ node-version: '20'
+
+ - name: Extract version from tag
+ id: version
+ run: |
+ TAG=${GITHUB_REF#refs/tags/server-v}
+ echo "version=$TAG" >> $GITHUB_OUTPUT
+ echo "Releasing version: $TAG"
+
+ - name: Install dependencies
+ run: make install
+
+ - name: Run tests
+ run: make test
+
+ - name: Build server
+ run: make version=${{ steps.version.outputs.version }} build-server
+
+ - name: Prepare Docker build context
+ run: |
+ VERSION="${{ steps.version.outputs.version }}"
+ cp build/server/dnote_server_${VERSION}_linux_amd64.tar.gz host/docker/
+
+ - name: Login to Docker Hub
+ uses: docker/login-action@v3
+ with:
+ username: ${{ secrets.DOCKER_USERNAME }}
+ password: ${{ secrets.DOCKER_TOKEN }}
+
+ - name: Build and push Docker image
+ uses: docker/build-push-action@v6
+ with:
+ context: ./host/docker
+ push: true
+ tags: |
+ dnote/dnote:${{ steps.version.outputs.version }}
+ dnote/dnote:latest
+ build-args: |
+ tarballName=dnote_server_${{ steps.version.outputs.version }}_linux_amd64.tar.gz
+
+ - name: Create GitHub release
+ env:
+ GH_TOKEN: ${{ github.token }}
+ run: |
+ VERSION="${{ steps.version.outputs.version }}"
+ TAG="server-v${VERSION}"
+
+ # Determine if prerelease (version not matching major.minor.patch)
+ FLAGS=""
+ if [[ ! "$VERSION" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
+ FLAGS="--prerelease"
+ fi
+
+ gh release create "$TAG" \
+ build/server/*.tar.gz \
+ build/server/*_checksums.txt \
+ $FLAGS \
+ --title="$TAG" \
+ --notes="Please see the [CHANGELOG](https://github.com/dnote/dnote/blob/master/CHANGELOG.md)" \
+ --draft
+
+ - name: Push to Docker Hub
+ env:
+ DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
+ DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
+ run: |
+ VERSION="${{ steps.version.outputs.version }}"
+
+ echo "$DOCKER_TOKEN" | docker login -u "$DOCKER_USERNAME" --password-stdin
+ docker push dnote/dnote:${VERSION}
+ docker push dnote/dnote:latest
diff --git a/.gitignore b/.gitignore
index c34f93b0..57d82ddc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,5 @@
node_modules
/test
tmp
+*.db
+server
diff --git a/Makefile b/Makefile
index 09c3513f..3b91d6ae 100644
--- a/Makefile
+++ b/Makefile
@@ -48,12 +48,6 @@ test-e2e:
@(${currentDir}/scripts/e2e/test.sh)
.PHONY: test-e2e
-test-selfhost:
- @echo "==> running a smoke test for self-hosting"
-
- @${currentDir}/host/smoketest/run_test.sh ${tarballPath}
-.PHONY: test-selfhost
-
# development
dev-server:
@echo "==> running dev environment"
diff --git a/SELF_HOSTING.md b/SELF_HOSTING.md
index 9edc82e1..0c52733c 100644
--- a/SELF_HOSTING.md
+++ b/SELF_HOSTING.md
@@ -4,48 +4,33 @@ This guide documents the steps for installing the Dnote server on your own machi
## Overview
-Dnote server comes as a single binary file that you can simply download and run. It uses Postgres as the database.
+Dnote server comes as a single binary file that you can simply download and run. It uses SQLite as the database.
## Installation
-1. Install Postgres 11+.
-2. Create a `dnote` database by running `createdb dnote`
-3. Download the official Dnote server release from the [release page](https://github.com/dnote/dnote/releases).
-4. Extract the archive and move the `dnote-server` executable to `/usr/local/bin`.
+1. Download the official Dnote server release from the [release page](https://github.com/dnote/dnote/releases).
+2. Extract the archive and move the `dnote-server` executable to `/usr/local/bin`.
```bash
tar -xzf dnote-server-$version-$os.tar.gz
mv ./dnote-server /usr/local/bin
```
-4. Run Dnote
+3. Run Dnote
```bash
-GO_ENV=PRODUCTION \
-OnPremises=true \
-DBHost=localhost \
-DBPort=5432 \
-DBName=dnote \
-DBUser=$user \
-DBPassword=$password \
-WebURL=$webURL \
-SmtpHost=$SmtpHost \
-SmtpPort=$SmtpPort \
-SmtpUsername=$SmtpUsername \
-SmtpPassword=$SmtpPassword \
-DisableRegistration=false \
- dnote-server start
+dnote-server start --webUrl=$webURL
```
-Replace `$user`, `$password` with the credentials of the Postgres user that owns the `dnote` database.
-
Replace `$webURL` with the full URL to your server, without a trailing slash (e.g. `https://your.server`).
-Replace `$SmtpHost`, `SmtpPort`, `$SmtpUsername`, `$SmtpPassword` with actual values, if you would like to receive spaced repetition through email.
+Additional flags:
+- `--port`: Server port (default: `3000`)
+- `--disableRegistration`: Disable user registration (default: `false`)
+- `--logLevel`: Log level: `debug`, `info`, `warn`, or `error` (default: `info`)
+- `--appEnv`: environment (default: `PRODUCTION`)
-Replace `DisableRegistration` to `true` if you would like to disable user registrations.
-
-By default, dnote server will run on the port 3000.
+You can also use environment variables: `PORT`, `WebURL`, `DisableRegistration`, `LOG_LEVEL`, `APP_ENV`.
## Configuration
@@ -127,33 +112,31 @@ User=$user
Restart=always
RestartSec=3
WorkingDirectory=/home/$user
-ExecStart=/usr/local/bin/dnote-server start
-Environment=GO_ENV=PRODUCTION
-Environment=OnPremises=true
-Environment=DBHost=localhost
-Environment=DBPort=5432
-Environment=DBName=dnote
-Environment=DBUser=$DBUser
-Environment=DBPassword=$DBPassword
-Environment=DBSkipSSL=true
-Environment=WebURL=$WebURL
-Environment=SmtpHost=
-Environment=SmtpPort=
-Environment=SmtpUsername=
-Environment=SmtpPassword=
+ExecStart=/usr/local/bin/dnote-server start --webUrl=$WebURL
[Install]
WantedBy=multi-user.target
```
-Replace `$user`, `$WebURL`, `$DBUser`, and `$DBPassword` with the actual values.
+Replace `$user` and `$WebURL` with the actual values.
-Optionally, if you would like to send spaced repetitions throught email, populate `SmtpHost`, `SmtpPort`, `SmtpUsername`, and `SmtpPassword`.
+By default, the database will be stored at `$XDG_DATA_HOME/dnote/server.db` (typically `~/.local/share/dnote/server.db`). To use a custom location, add `--dbPath=/path/to/database.db` to the `ExecStart` command.
2. Reload the change by running `sudo systemctl daemon-reload`.
3. Enable the Daemon by running `sudo systemctl enable dnote`.`
4. Start the Daemon by running `sudo systemctl start dnote`
+### Optional: Email Support
+
+To enable sending emails, add the following environment variables to your configuration. But they are not required.
+
+- `SmtpHost` - SMTP server hostname
+- `SmtpPort` - SMTP server port
+- `SmtpUsername` - SMTP username
+- `SmtpPassword` - SMTP password
+
+For systemd, add these as additional `Environment=` lines in `/etc/systemd/system/dnote.service`.
+
### Configure clients
Let's configure Dnote clients to connect to the self-hosted web API endpoint.
@@ -166,7 +149,7 @@ The following is an example configuration:
```yaml
editor: nvim
-apiEndpoint: https://api.getdnote.com
+apiEndpoint: https://localhost:3000/api
```
Simply change the value for `apiEndpoint` to a full URL to the self-hosted instance, followed by '/api', and save the configuration file.
@@ -177,7 +160,3 @@ e.g.
editor: nvim
apiEndpoint: my-dnote-server.com/api
```
-
-#### Browser extension
-
-Navigate into the 'Settings' tab and set the values for 'API URL', and 'Web URL'.
diff --git a/go.mod b/go.mod
index 275818a9..3b87c5ec 100644
--- a/go.mod
+++ b/go.mod
@@ -3,7 +3,6 @@ module github.com/dnote/dnote
go 1.25
require (
- github.com/aymerick/douceur v0.2.0
github.com/dnote/actions v0.2.0
github.com/fatih/color v1.18.0
github.com/google/go-cmp v0.7.0
@@ -12,42 +11,34 @@ require (
github.com/gorilla/csrf v1.7.3
github.com/gorilla/mux v1.8.1
github.com/gorilla/schema v1.4.1
- github.com/joho/godotenv v1.5.1
- github.com/lib/pq v1.10.9
github.com/mattn/go-sqlite3 v1.14.32
github.com/pkg/errors v0.9.1
github.com/radovskyb/watcher v1.0.7
github.com/robfig/cron v1.2.0
- github.com/rubenv/sql-migrate v1.8.0
github.com/sergi/go-diff v1.3.1
github.com/spf13/cobra v1.10.1
golang.org/x/crypto v0.42.0
golang.org/x/time v0.13.0
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df
gopkg.in/yaml.v2 v2.4.0
- gorm.io/driver/postgres v1.5.7
- gorm.io/gorm v1.25.7
+ gorm.io/driver/sqlite v1.6.0
+ gorm.io/gorm v1.30.0
)
require (
- github.com/PuerkitoBio/goquery v1.10.3 // indirect
- github.com/andybalholm/cascadia v1.3.3 // indirect
- github.com/go-gorp/gorp/v3 v3.1.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
- github.com/gorilla/css v1.0.1 // indirect
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
- github.com/jackc/pgpassfile v1.0.0 // indirect
- github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
- github.com/jackc/pgx/v5 v5.4.3 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
+ github.com/kr/pretty v0.3.1 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/spf13/pflag v1.0.10 // indirect
- golang.org/x/net v0.44.0 // indirect
+ github.com/stretchr/testify v1.8.1 // indirect
golang.org/x/sys v0.36.0 // indirect
golang.org/x/term v0.35.0 // indirect
golang.org/x/text v0.29.0 // indirect
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
+ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
)
diff --git a/go.sum b/go.sum
index 6d099387..e175da0d 100644
--- a/go.sum
+++ b/go.sum
@@ -1,10 +1,5 @@
-github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo=
-github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y=
-github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kktS1LM=
-github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA=
-github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
-github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
+github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -12,12 +7,7 @@ github.com/dnote/actions v0.2.0 h1:P1ut2/QRKwfAzIIB374vN9A4IanU94C/payEocvngYo=
github.com/dnote/actions v0.2.0/go.mod h1:bBIassLhppVQdbC3iaE92SHBpM1HOVe+xZoAlj9ROxw=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
-github.com/go-gorp/gorp/v3 v3.1.0 h1:ItKF/Vbuj31dmV4jxA1qblpSwkl9g1typ24xoe70IGs=
-github.com/go-gorp/gorp/v3 v3.1.0/go.mod h1:dLEjIyyRNiXvNZ8PSmzpt1GsWAUK8kjVhEpjH8TixEw=
-github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
-github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY=
@@ -30,8 +20,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0=
github.com/gorilla/csrf v1.7.3/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk=
-github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
-github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/gorilla/schema v1.4.1 h1:jUg5hUjCSDZpNGLuXQOgIWGdlgrIdYvgQ0wZtdK1M3E=
@@ -40,47 +28,35 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
-github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
-github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
-github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
-github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
-github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
-github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
-github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
-github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
+github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
-github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
-github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
+github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/poy/onpar v1.1.2 h1:QaNrNiZx0+Nar5dLgTVp5mXkyoVFIbepjyEoGSnhbAY=
-github.com/poy/onpar v1.1.2/go.mod h1:6X8FLNoxyr9kkmnlqpK6LSoiOtrO6MICtWwEuWkLjzg=
github.com/radovskyb/watcher v1.0.7 h1:AYePLih6dpmS32vlHfhCeli8127LzkIgwJGcwwe8tUE=
github.com/radovskyb/watcher v1.0.7/go.mod h1:78okwvY5wPdzcb1UYnip1pvrZNIVEIh/Cm+ZuvsUYIg=
github.com/robfig/cron v1.2.0 h1:ZjScXvvxeQ63Dbyxy76Fj3AT3Ut0aKsyd2/tl3DTMuQ=
github.com/robfig/cron v1.2.0/go.mod h1:JGuDeoQd7Z6yL4zQhZ3OPEVHB7fL6Ka6skscFHfmt2k=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
-github.com/rubenv/sql-migrate v1.8.0 h1:dXnYiJk9k3wetp7GfQbKJcPHjVJL6YK19tKj8t2Ns0o=
-github.com/rubenv/sql-migrate v1.8.0/go.mod h1:F2bGFBwCU+pnmbtNYDeKvSuvL6lBVtXDXUUv5t+u1qw=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8=
github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I=
@@ -90,88 +66,24 @@ github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
-github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
-github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
-golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
-golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
-golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
-golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
-golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
-golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
-golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
-golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
-golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
-golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
-golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
-golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
-golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
-golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
-golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
-golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
-golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
-golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I=
-golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
-golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
-golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
-golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
-golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
-golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
-golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
-golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
-golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
-golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
-golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
-golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
-golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
-golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ=
golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA=
-golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
-golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
-golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
-golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
-golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
-golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
-golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI=
golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
-golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
-golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
-golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
-golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
-golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk=
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk=
@@ -187,7 +99,7 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM=
-gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA=
-gorm.io/gorm v1.25.7 h1:VsD6acwRjz2zFxGO50gPO6AkNs7KKnvfzUjHQhZDz/A=
-gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
+gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
+gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
+gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=
+gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
diff --git a/host/docker/compose.yml b/host/docker/compose.yml
index 446e841d..d2c2e5ed 100644
--- a/host/docker/compose.yml
+++ b/host/docker/compose.yml
@@ -1,35 +1,14 @@
version: "3"
services:
- postgres:
- image: postgres:14-alpine
- environment:
- POSTGRES_USER: dnote
- POSTGRES_PASSWORD: dnote
- POSTGRES_DB: dnote
- volumes:
- - ./dnote_data:/var/lib/postgresql/data
- restart: always
-
dnote:
image: dnote/dnote:latest
environment:
- GO_ENV: PRODUCTION
- DBSkipSSL: "true"
- DBHost: postgres
- DBPort: 5432
- DBName: dnote
- DBUser: dnote
- DBPassword: dnote
+ APP_ENV: PRODUCTION
WebURL: localhost:3000
- OnPremises: "true"
- SmtpHost:
- SmtpPort:
- SmtpUsername:
- SmtpPassword:
DisableRegistration: "false"
ports:
- 3000:3000
- depends_on:
- - postgres
+ volumes:
+ - ./dnote_data:/data
restart: always
diff --git a/host/docker/entrypoint.sh b/host/docker/entrypoint.sh
index 8fb62de9..0fee185c 100755
--- a/host/docker/entrypoint.sh
+++ b/host/docker/entrypoint.sh
@@ -1,25 +1,6 @@
#!/bin/sh
-wait_for_db() {
- HOST=${DBHost:-postgres}
- PORT=${DBPort:-5432}
- echo "Waiting for the database connection..."
-
- attempts=0
- max_attempts=10
- while [ $attempts -lt $max_attempts ]; do
- nc -z "${HOST}" "${PORT}" 2>/dev/null && break
- echo "Waiting for db at ${HOST}:${PORT}..."
- sleep 5
- attempts=$((attempts+1))
- done
-
- if [ $attempts -eq $max_attempts ]; then
- echo "Timed out while waiting for db at ${HOST}:${PORT}"
- exit 1
- fi
-}
-
-wait_for_db
+# Set default DBPath to /data if not specified
+export DBPath=${DBPath:-/data/dnote.db}
exec "$@"
diff --git a/host/smoketest/.gitignore b/host/smoketest/.gitignore
deleted file mode 100644
index 463ebfd4..00000000
--- a/host/smoketest/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-/volume
diff --git a/host/smoketest/README.md b/host/smoketest/README.md
deleted file mode 100644
index a428896a..00000000
--- a/host/smoketest/README.md
+++ /dev/null
@@ -1,9 +0,0 @@
-This directory contains a smoke test for running a self-hosted instance using a virtual machine.
-
-## Instruction
-
-The following script will set up a test environment in Vagrant and run the test.
-
-```
-./run_test.sh
-```
diff --git a/host/smoketest/Vagrantfile b/host/smoketest/Vagrantfile
deleted file mode 100644
index 2eaee5dd..00000000
--- a/host/smoketest/Vagrantfile
+++ /dev/null
@@ -1,9 +0,0 @@
-# -*- mode: ruby -*-
-
-Vagrant.configure("2") do |config|
- config.vm.box = "ubuntu/jammy64"
- config.vm.synced_folder './volume', '/vagrant'
- config.vm.network "forwarded_port", guest: 2300, host: 2300
-
- config.vm.provision 'shell', path: './setup.sh', privileged: false
-end
diff --git a/host/smoketest/run_test.sh b/host/smoketest/run_test.sh
deleted file mode 100755
index 8137dfcd..00000000
--- a/host/smoketest/run_test.sh
+++ /dev/null
@@ -1,42 +0,0 @@
-#!/usr/bin/env bash
-# run_test.sh builds a fresh server image, and mounts it on a fresh
-# virtual machine and runs a smoke test. If a tarball path is not provided,
-# this script builds a new version and uses it.
-set -ex
-
-# tarballPath is an absolute path to a release tarball containing the dnote server.
-tarballPath=$1
-
-dir=$(dirname "${BASH_SOURCE[0]}")
-projectDir="$dir/../.."
-
-# build
-if [ -z "$tarballPath" ]; then
- pushd "$projectDir"
- make version=integration_test build-server
- popd
- tarballPath="$projectDir/build/server/dnote_server_integration_test_linux_amd64.tar.gz"
-fi
-
-pushd "$dir"
-
-# start a virtual machine
-volume="$dir/volume"
-rm -rf "$volume"
-mkdir -p "$volume"
-cp "$tarballPath" "$volume"
-cp "$dir/testsuite.sh" "$volume"
-
-vagrant up
-
-# run tests
-set +e
-if ! vagrant ssh -c "/vagrant/testsuite.sh"; then
- echo "Test failed. Please see the output."
- vagrant halt
- exit 1
-fi
-set -e
-
-vagrant halt
-popd
diff --git a/host/smoketest/setup.sh b/host/smoketest/setup.sh
deleted file mode 100755
index 575b7fd8..00000000
--- a/host/smoketest/setup.sh
+++ /dev/null
@@ -1,21 +0,0 @@
-#!/usr/bin/env bash
-set -ex
-
-sudo apt-get install wget ca-certificates
-wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
-sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt/ `lsb_release -cs`-pgdg main" >> /etc/apt/sources.list.d/pgdg.list'
-
-sudo apt-get update
-sudo apt-get install -y postgresql-14
-
-# set up database
-sudo usermod -a -G sudo postgres
-cd /var/lib/postgresql
-sudo -u postgres createdb dnote
-sudo -u postgres psql -c "ALTER USER postgres PASSWORD 'postgres';"
-
-# allow connection from host and allow to connect without password
-sudo sed -i "/port*/a listen_addresses = '*'" /etc/postgresql/14/main/postgresql.conf
-sudo sed -i 's/host.*all.*.all.*md5/# &/' /etc/postgresql/14/main/pg_hba.conf
-sudo sed -i "$ a host all all all trust" /etc/postgresql/14/main/pg_hba.conf
-sudo service postgresql restart
diff --git a/host/smoketest/testsuite.sh b/host/smoketest/testsuite.sh
deleted file mode 100755
index 4fa05af3..00000000
--- a/host/smoketest/testsuite.sh
+++ /dev/null
@@ -1,45 +0,0 @@
-#!/usr/bin/env bash
-# testsuite.sh runs the smoke tests for a self-hosted instance.
-# It is meant to be run inside a virtual machine which has been
-# set up by an entry script.
-set -eux
-
-echo 'Running a smoke test'
-
-cd /var/lib/postgresql
-sudo -u postgres dropdb dnote
-sudo -u postgres createdb dnote
-
-cd /vagrant
-
-tar -xvf dnote_server_integration_test_linux_amd64.tar.gz
-
-GO_ENV=PRODUCTION \
- DBHost=localhost \
- DBPort=5432 \
- DBName=dnote \
- DBUser=postgres \
- DBPassword=postgres \
- WebURL=localhost:3000 \
- ./dnote-server -port 2300 start & sleep 3
-
-assert_http_status() {
- url=$1
- expected=$2
-
- echo "======== [TEST CASE] asserting response status code for $url ========"
-
- got=$(curl --write-out %"{http_code}" --silent --output /dev/null "$url")
-
- if [ "$got" != "$expected" ]; then
- echo "======== ASSERTION FAILED ========"
- echo "status code for $url: expected: $expected got: $got"
- echo "=================================="
- exit 1
- fi
-}
-
-assert_http_status http://localhost:2300 "302"
-assert_http_status http://localhost:2300/health "200"
-
-echo "======== [SUCCESS] TEST PASSED! ========"
diff --git a/pkg/assert/assert.go b/pkg/assert/assert.go
index 94c7b411..000d7b48 100644
--- a/pkg/assert/assert.go
+++ b/pkg/assert/assert.go
@@ -22,7 +22,7 @@ package assert
import (
"encoding/json"
"fmt"
- "io/ioutil"
+ "io"
"net/http"
"reflect"
"runtime/debug"
@@ -138,7 +138,7 @@ func EqualJSON(t *testing.T, a, b, message string) {
// expected
func StatusCodeEquals(t *testing.T, res *http.Response, expected int, message string) {
if res.StatusCode != expected {
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(errors.Wrap(err, "reading body"))
}
diff --git a/pkg/cli/client/client.go b/pkg/cli/client/client.go
index 59cfdb59..b90ddaca 100644
--- a/pkg/cli/client/client.go
+++ b/pkg/cli/client/client.go
@@ -23,7 +23,7 @@ package client
import (
"encoding/json"
"fmt"
- "io/ioutil"
+ "io"
"net/http"
"net/url"
"strconv"
@@ -95,7 +95,7 @@ func checkRespErr(res *http.Response) error {
return nil
}
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
if err != nil {
return errors.Wrapf(err, "server responded with %d but client could not read the response body", res.StatusCode)
}
@@ -169,7 +169,7 @@ func GetSyncState(ctx context.DnoteCtx) (GetSyncStateResp, error) {
return ret, errors.Wrap(err, "constructing http request")
}
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
if err != nil {
return ret, errors.Wrap(err, "reading the response body")
}
@@ -233,7 +233,7 @@ func GetSyncFragment(ctx context.DnoteCtx, afterUSN int) (GetSyncFragmentResp, e
path := fmt.Sprintf("/v3/sync/fragment?%s", queryStr)
res, err := doAuthorizedReq(ctx, "GET", path, "", nil)
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
if err != nil {
return GetSyncFragmentResp{}, errors.Wrap(err, "reading the response body")
}
diff --git a/pkg/cli/cmd/edit/note.go b/pkg/cli/cmd/edit/note.go
index ba8968c3..f6019778 100644
--- a/pkg/cli/cmd/edit/note.go
+++ b/pkg/cli/cmd/edit/note.go
@@ -20,7 +20,7 @@ package edit
import (
"database/sql"
- "io/ioutil"
+ "os"
"strconv"
"github.com/dnote/dnote/pkg/cli/context"
@@ -45,7 +45,7 @@ func waitEditorNoteContent(ctx context.DnoteCtx, note database.Note) (string, er
return "", errors.Wrap(err, "getting temporarily content file path")
}
- if err := ioutil.WriteFile(fpath, []byte(note.Body), 0644); err != nil {
+ if err := os.WriteFile(fpath, []byte(note.Body), 0644); err != nil {
return "", errors.Wrap(err, "preparing tmp content file")
}
diff --git a/pkg/cli/config/config.go b/pkg/cli/config/config.go
index 3ef76a6b..b0faa9cf 100644
--- a/pkg/cli/config/config.go
+++ b/pkg/cli/config/config.go
@@ -20,7 +20,7 @@ package config
import (
"fmt"
- "io/ioutil"
+ "os"
"github.com/dnote/dnote/pkg/cli/consts"
"github.com/dnote/dnote/pkg/cli/context"
@@ -66,7 +66,7 @@ func Read(ctx context.DnoteCtx) (Config, error) {
var ret Config
configPath := GetPath(ctx)
- b, err := ioutil.ReadFile(configPath)
+ b, err := os.ReadFile(configPath)
if err != nil {
return ret, errors.Wrap(err, "reading config file")
}
@@ -88,7 +88,7 @@ func Write(ctx context.DnoteCtx, cf Config) error {
return errors.Wrap(err, "marshalling config into YAML")
}
- err = ioutil.WriteFile(path, b, 0644)
+ err = os.WriteFile(path, b, 0644)
if err != nil {
return errors.Wrap(err, "writing the config file")
}
diff --git a/pkg/cli/infra/init.go b/pkg/cli/infra/init.go
index cecd022f..72286cad 100644
--- a/pkg/cli/infra/init.go
+++ b/pkg/cli/infra/init.go
@@ -32,11 +32,11 @@ import (
"github.com/dnote/dnote/pkg/cli/consts"
"github.com/dnote/dnote/pkg/cli/context"
"github.com/dnote/dnote/pkg/cli/database"
- "github.com/dnote/dnote/pkg/cli/dirs"
"github.com/dnote/dnote/pkg/cli/log"
"github.com/dnote/dnote/pkg/cli/migrate"
"github.com/dnote/dnote/pkg/cli/utils"
"github.com/dnote/dnote/pkg/clock"
+ "github.com/dnote/dnote/pkg/dirs"
"github.com/pkg/errors"
"github.com/spf13/cobra"
)
diff --git a/pkg/cli/migrate/legacy.go b/pkg/cli/migrate/legacy.go
index d0dab97b..dc5441c7 100644
--- a/pkg/cli/migrate/legacy.go
+++ b/pkg/cli/migrate/legacy.go
@@ -23,7 +23,6 @@ package migrate
import (
"encoding/json"
"fmt"
- "io/ioutil"
"os"
"time"
@@ -232,7 +231,7 @@ func readSchema(ctx context.DnoteCtx) (schema, error) {
path := getSchemaPath(ctx)
- b, err := ioutil.ReadFile(path)
+ b, err := os.ReadFile(path)
if err != nil {
return ret, errors.Wrap(err, "Failed to read schema file")
}
@@ -252,7 +251,7 @@ func writeSchema(ctx context.DnoteCtx, s schema) error {
return errors.Wrap(err, "Failed to marshal schema into yaml")
}
- if err := ioutil.WriteFile(path, d, 0644); err != nil {
+ if err := os.WriteFile(path, d, 0644); err != nil {
return errors.Wrap(err, "Failed to write schema file")
}
@@ -504,7 +503,7 @@ func migrateToV1(ctx context.DnoteCtx) error {
func migrateToV2(ctx context.DnoteCtx) error {
notePath := fmt.Sprintf("%s/dnote", ctx.Paths.LegacyDnote)
- b, err := ioutil.ReadFile(notePath)
+ b, err := os.ReadFile(notePath)
if err != nil {
return errors.Wrap(err, "Failed to read the note file")
}
@@ -548,7 +547,7 @@ func migrateToV2(ctx context.DnoteCtx) error {
return errors.Wrap(err, "Failed to marshal new dnote into JSON")
}
- err = ioutil.WriteFile(notePath, d, 0644)
+ err = os.WriteFile(notePath, d, 0644)
if err != nil {
return errors.Wrap(err, "Failed to write the new dnote into the file")
}
@@ -561,7 +560,7 @@ func migrateToV3(ctx context.DnoteCtx) error {
notePath := fmt.Sprintf("%s/dnote", ctx.Paths.LegacyDnote)
actionsPath := fmt.Sprintf("%s/actions", ctx.Paths.LegacyDnote)
- b, err := ioutil.ReadFile(notePath)
+ b, err := os.ReadFile(notePath)
if err != nil {
return errors.Wrap(err, "Failed to read the note file")
}
@@ -615,7 +614,7 @@ func migrateToV3(ctx context.DnoteCtx) error {
return errors.Wrap(err, "Failed to marshal actions into JSON")
}
- err = ioutil.WriteFile(actionsPath, a, 0644)
+ err = os.WriteFile(actionsPath, a, 0644)
if err != nil {
return errors.Wrap(err, "Failed to write the actions into a file")
}
@@ -647,7 +646,7 @@ func getEditorCommand() string {
func migrateToV4(ctx context.DnoteCtx) error {
configPath := fmt.Sprintf("%s/dnoterc", ctx.Paths.LegacyDnote)
- b, err := ioutil.ReadFile(configPath)
+ b, err := os.ReadFile(configPath)
if err != nil {
return errors.Wrap(err, "Failed to read the config file")
}
@@ -668,7 +667,7 @@ func migrateToV4(ctx context.DnoteCtx) error {
return errors.Wrap(err, "Failed to marshal config into JSON")
}
- err = ioutil.WriteFile(configPath, data, 0644)
+ err = os.WriteFile(configPath, data, 0644)
if err != nil {
return errors.Wrap(err, "Failed to write the config into a file")
}
@@ -680,7 +679,7 @@ func migrateToV4(ctx context.DnoteCtx) error {
func migrateToV5(ctx context.DnoteCtx) error {
actionsPath := fmt.Sprintf("%s/actions", ctx.Paths.LegacyDnote)
- b, err := ioutil.ReadFile(actionsPath)
+ b, err := os.ReadFile(actionsPath)
if err != nil {
return errors.Wrap(err, "reading the actions file")
}
@@ -738,7 +737,7 @@ func migrateToV5(ctx context.DnoteCtx) error {
if err != nil {
return errors.Wrap(err, "marshalling result into JSON")
}
- err = ioutil.WriteFile(actionsPath, a, 0644)
+ err = os.WriteFile(actionsPath, a, 0644)
if err != nil {
return errors.Wrap(err, "writing the result into a file")
}
@@ -750,7 +749,7 @@ func migrateToV5(ctx context.DnoteCtx) error {
func migrateToV6(ctx context.DnoteCtx) error {
notePath := fmt.Sprintf("%s/dnote", ctx.Paths.LegacyDnote)
- b, err := ioutil.ReadFile(notePath)
+ b, err := os.ReadFile(notePath)
if err != nil {
return errors.Wrap(err, "Failed to read the note file")
}
@@ -791,7 +790,7 @@ func migrateToV6(ctx context.DnoteCtx) error {
return errors.Wrap(err, "Failed to marshal new dnote into JSON")
}
- err = ioutil.WriteFile(notePath, d, 0644)
+ err = os.WriteFile(notePath, d, 0644)
if err != nil {
return errors.Wrap(err, "Failed to write the new dnote into the file")
}
@@ -805,7 +804,7 @@ func migrateToV6(ctx context.DnoteCtx) error {
func migrateToV7(ctx context.DnoteCtx) error {
actionPath := fmt.Sprintf("%s/actions", ctx.Paths.LegacyDnote)
- b, err := ioutil.ReadFile(actionPath)
+ b, err := os.ReadFile(actionPath)
if err != nil {
return errors.Wrap(err, "reading actions file")
}
@@ -857,7 +856,7 @@ func migrateToV7(ctx context.DnoteCtx) error {
return errors.Wrap(err, "marshalling new actions")
}
- err = ioutil.WriteFile(actionPath, d, 0644)
+ err = os.WriteFile(actionPath, d, 0644)
if err != nil {
return errors.Wrap(err, "writing new actions to a file")
}
@@ -874,7 +873,7 @@ func migrateToV8(ctx context.DnoteCtx) error {
// 1. Migrate the the dnote file
dnoteFilePath := fmt.Sprintf("%s/dnote", ctx.Paths.LegacyDnote)
- b, err := ioutil.ReadFile(dnoteFilePath)
+ b, err := os.ReadFile(dnoteFilePath)
if err != nil {
return errors.Wrap(err, "reading the notes")
}
@@ -914,7 +913,7 @@ func migrateToV8(ctx context.DnoteCtx) error {
// 2. Migrate the actions file
actionsPath := fmt.Sprintf("%s/actions", ctx.Paths.LegacyDnote)
- b, err = ioutil.ReadFile(actionsPath)
+ b, err = os.ReadFile(actionsPath)
if err != nil {
return errors.Wrap(err, "reading the actions")
}
@@ -939,7 +938,7 @@ func migrateToV8(ctx context.DnoteCtx) error {
// 3. Migrate the timestamps file
timestampsPath := fmt.Sprintf("%s/timestamps", ctx.Paths.LegacyDnote)
- b, err = ioutil.ReadFile(timestampsPath)
+ b, err = os.ReadFile(timestampsPath)
if err != nil {
return errors.Wrap(err, "reading the timestamps")
}
diff --git a/pkg/cli/migrate/legacy_test.go b/pkg/cli/migrate/legacy_test.go
index 211feee7..00ebb7d7 100644
--- a/pkg/cli/migrate/legacy_test.go
+++ b/pkg/cli/migrate/legacy_test.go
@@ -21,7 +21,6 @@ package migrate
import (
"encoding/json"
"fmt"
- "io/ioutil"
"os"
"path/filepath"
"testing"
@@ -65,7 +64,7 @@ func TestMigrateToV1(t *testing.T) {
if err != nil {
panic(errors.Wrap(err, "Failed to get absolute YAML path").Error())
}
- ioutil.WriteFile(yamlPath, []byte{}, 0644)
+ os.WriteFile(yamlPath, []byte{}, 0644)
// execute
if err := migrateToV1(ctx); err != nil {
diff --git a/pkg/cli/migrate/migrate_test.go b/pkg/cli/migrate/migrate_test.go
index bfda7768..cd2619bd 100644
--- a/pkg/cli/migrate/migrate_test.go
+++ b/pkg/cli/migrate/migrate_test.go
@@ -21,8 +21,8 @@ package migrate
import (
"encoding/json"
"fmt"
- "io/ioutil"
"net/http"
+ "os"
"net/http/httptest"
"testing"
"time"
@@ -1079,7 +1079,7 @@ func TestLocalMigration12(t *testing.T) {
data := []byte("editor: vim")
path := fmt.Sprintf("%s/%s/dnoterc", ctx.Paths.Config, consts.DnoteDirName)
- if err := ioutil.WriteFile(path, data, 0644); err != nil {
+ if err := os.WriteFile(path, data, 0644); err != nil {
t.Fatal(errors.Wrap(err, "Failed to write schema file"))
}
@@ -1090,7 +1090,7 @@ func TestLocalMigration12(t *testing.T) {
}
// test
- b, err := ioutil.ReadFile(path)
+ b, err := os.ReadFile(path)
if err != nil {
t.Fatal(errors.Wrap(err, "reading config"))
}
@@ -1117,7 +1117,7 @@ func TestLocalMigration13(t *testing.T) {
data := []byte("editor: vim\napiEndpoint: https://test.com/api")
path := fmt.Sprintf("%s/%s/dnoterc", ctx.Paths.Config, consts.DnoteDirName)
- if err := ioutil.WriteFile(path, data, 0644); err != nil {
+ if err := os.WriteFile(path, data, 0644); err != nil {
t.Fatal(errors.Wrap(err, "Failed to write schema file"))
}
@@ -1128,7 +1128,7 @@ func TestLocalMigration13(t *testing.T) {
}
// test
- b, err := ioutil.ReadFile(path)
+ b, err := os.ReadFile(path)
if err != nil {
t.Fatal(errors.Wrap(err, "reading config"))
}
diff --git a/pkg/cli/testutils/main.go b/pkg/cli/testutils/main.go
index bdcc00d2..afc12f59 100644
--- a/pkg/cli/testutils/main.go
+++ b/pkg/cli/testutils/main.go
@@ -23,7 +23,6 @@ import (
"bytes"
"encoding/json"
"io"
- "io/ioutil"
"os"
"os/exec"
"path/filepath"
@@ -81,7 +80,7 @@ func WriteFile(ctx context.DnoteCtx, content []byte, filename string) {
panic(err)
}
- if err := ioutil.WriteFile(dp, content, 0644); err != nil {
+ if err := os.WriteFile(dp, content, 0644); err != nil {
panic(err)
}
}
@@ -90,7 +89,7 @@ func WriteFile(ctx context.DnoteCtx, content []byte, filename string) {
func ReadFile(ctx context.DnoteCtx, filename string) []byte {
path := filepath.Join(ctx.Paths.LegacyDnote, filename)
- b, err := ioutil.ReadFile(path)
+ b, err := os.ReadFile(path)
if err != nil {
panic(err)
}
@@ -101,7 +100,7 @@ func ReadFile(ctx context.DnoteCtx, filename string) []byte {
// ReadJSON reads JSON fixture to the struct at the destination address
func ReadJSON(path string, destination interface{}) {
var dat []byte
- dat, err := ioutil.ReadFile(path)
+ dat, err := os.ReadFile(path)
if err != nil {
panic(errors.Wrap(err, "Failed to load fixture payload"))
}
diff --git a/pkg/cli/ui/editor.go b/pkg/cli/ui/editor.go
index afafcaa3..3501e7ae 100644
--- a/pkg/cli/ui/editor.go
+++ b/pkg/cli/ui/editor.go
@@ -21,7 +21,6 @@ package ui
import (
"fmt"
- "io/ioutil"
"os"
"os/exec"
"strings"
@@ -122,7 +121,7 @@ func GetEditorInput(ctx context.DnoteCtx, fpath string) (string, error) {
return "", errors.Wrap(err, "waiting for the editor")
}
- b, err := ioutil.ReadFile(fpath)
+ b, err := os.ReadFile(fpath)
if err != nil {
return "", errors.Wrap(err, "reading the temporary content file")
}
diff --git a/pkg/cli/utils/files.go b/pkg/cli/utils/files.go
index 1335fc2b..b4b2d1df 100644
--- a/pkg/cli/utils/files.go
+++ b/pkg/cli/utils/files.go
@@ -20,7 +20,6 @@ package utils
import (
"io"
- "io/ioutil"
"os"
"path/filepath"
@@ -35,7 +34,7 @@ func ReadFileAbs(relpath string) []byte {
panic(err)
}
- b, err := ioutil.ReadFile(fp)
+ b, err := os.ReadFile(fp)
if err != nil {
panic(err)
}
@@ -80,7 +79,7 @@ func CopyDir(src, dest string) error {
return errors.Wrap(err, "creating destination")
}
- entries, err := ioutil.ReadDir(src)
+ entries, err := os.ReadDir(src)
if err != nil {
return errors.Wrap(err, "reading the directory listing for the input")
}
diff --git a/pkg/cli/dirs/dirs.go b/pkg/dirs/dirs.go
similarity index 100%
rename from pkg/cli/dirs/dirs.go
rename to pkg/dirs/dirs.go
diff --git a/pkg/cli/dirs/dirs_test.go b/pkg/dirs/dirs_test.go
similarity index 96%
rename from pkg/cli/dirs/dirs_test.go
rename to pkg/dirs/dirs_test.go
index 0e5b0394..6fe3d1e0 100644
--- a/pkg/cli/dirs/dirs_test.go
+++ b/pkg/dirs/dirs_test.go
@@ -19,7 +19,6 @@
package dirs
import (
- "os"
"testing"
"github.com/dnote/dnote/pkg/assert"
@@ -34,7 +33,7 @@ type envTestCase struct {
func testCustomDirs(t *testing.T, testCases []envTestCase) {
for _, tc := range testCases {
- os.Setenv(tc.envKey, tc.envVal)
+ t.Setenv(tc.envKey, tc.envVal)
Reload()
diff --git a/pkg/cli/dirs/dirs_unix.go b/pkg/dirs/dirs_unix.go
similarity index 100%
rename from pkg/cli/dirs/dirs_unix.go
rename to pkg/dirs/dirs_unix.go
diff --git a/pkg/cli/dirs/dirs_unix_test.go b/pkg/dirs/dirs_unix_test.go
similarity index 100%
rename from pkg/cli/dirs/dirs_unix_test.go
rename to pkg/dirs/dirs_unix_test.go
diff --git a/pkg/cli/dirs/dirs_windows.go b/pkg/dirs/dirs_windows.go
similarity index 100%
rename from pkg/cli/dirs/dirs_windows.go
rename to pkg/dirs/dirs_windows.go
diff --git a/pkg/cli/dirs/dirs_windows_test.go b/pkg/dirs/dirs_windows_test.go
similarity index 100%
rename from pkg/cli/dirs/dirs_windows_test.go
rename to pkg/dirs/dirs_windows_test.go
diff --git a/pkg/e2e/server_test.go b/pkg/e2e/server_test.go
new file mode 100644
index 00000000..52a2d311
--- /dev/null
+++ b/pkg/e2e/server_test.go
@@ -0,0 +1,183 @@
+/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
+ *
+ * This file is part of Dnote.
+ *
+ * Dnote is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * Dnote is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with Dnote. If not, see .
+ */
+
+package main
+
+import (
+ "fmt"
+ "net/http"
+ "os"
+ "os/exec"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/dnote/dnote/pkg/assert"
+ "gorm.io/driver/sqlite"
+ "gorm.io/gorm"
+)
+
+var testServerBinary string
+
+func init() {
+ // Build server binary in temp directory
+ tmpDir := os.TempDir()
+ testServerBinary = fmt.Sprintf("%s/dnote-test-server", tmpDir)
+ buildCmd := exec.Command("go", "build", "-tags", "fts5", "-o", testServerBinary, "../server")
+ if out, err := buildCmd.CombinedOutput(); err != nil {
+ panic(fmt.Sprintf("failed to build server: %v\n%s", err, out))
+ }
+}
+
+func TestServerStart(t *testing.T) {
+ tmpDB := t.TempDir() + "/test.db"
+ port := "13456" // Use different port to avoid conflicts with main test server
+
+ // Start server in background
+ cmd := exec.Command(testServerBinary, "start", "-port", port)
+ cmd.Env = append(os.Environ(),
+ "DBPath="+tmpDB,
+ "WebURL=http://localhost:"+port,
+ "APP_ENV=PRODUCTION",
+ )
+
+ if err := cmd.Start(); err != nil {
+ t.Fatalf("failed to start server: %v", err)
+ }
+
+ // Ensure cleanup
+ cleanup := func() {
+ if cmd.Process != nil {
+ cmd.Process.Kill()
+ cmd.Wait() // Wait for process to fully exit
+ }
+ }
+ defer cleanup()
+
+ // Wait for server to start and migrations to run
+ time.Sleep(3 * time.Second)
+
+ // Verify server responds to health check
+ resp, err := http.Get(fmt.Sprintf("http://localhost:%s/health", port))
+ if err != nil {
+ t.Fatalf("failed to reach server health endpoint: %v", err)
+ }
+ defer resp.Body.Close()
+
+ assert.Equal(t, resp.StatusCode, 200, "health endpoint should return 200")
+
+ // Kill server before checking database to avoid locks
+ cleanup()
+
+ // Verify database file was created
+ if _, err := os.Stat(tmpDB); os.IsNotExist(err) {
+ t.Fatalf("database file was not created at %s", tmpDB)
+ }
+
+ // Verify migrations ran by checking database
+ db, err := gorm.Open(sqlite.Open(tmpDB), &gorm.Config{})
+ if err != nil {
+ t.Fatalf("failed to open test database: %v", err)
+ }
+
+ // Verify migrations ran
+ var count int64
+ if err := db.Raw("SELECT COUNT(*) FROM schema_migrations").Scan(&count).Error; err != nil {
+ t.Fatalf("schema_migrations table not found: %v", err)
+ }
+ if count == 0 {
+ t.Fatal("no migrations were run")
+ }
+
+ // Verify FTS table exists and is functional
+ if err := db.Exec("SELECT * FROM notes_fts LIMIT 1").Error; err != nil {
+ t.Fatalf("notes_fts table not found or not functional: %v", err)
+ }
+}
+
+func TestServerVersion(t *testing.T) {
+ cmd := exec.Command("go", "run", "-tags", "fts5", "../server", "version")
+ output, err := cmd.CombinedOutput()
+ if err != nil {
+ t.Fatalf("version command failed: %v", err)
+ }
+
+ outputStr := string(output)
+ if !strings.Contains(outputStr, "dnote-server-") {
+ t.Errorf("expected version output to contain 'dnote-server-', got: %s", outputStr)
+ }
+}
+
+func TestServerRootCommand(t *testing.T) {
+ cmd := exec.Command(testServerBinary)
+ output, err := cmd.CombinedOutput()
+ if err != nil {
+ t.Fatalf("server command failed: %v", err)
+ }
+
+ outputStr := string(output)
+ assert.Equal(t, strings.Contains(outputStr, "Dnote server - a simple command line notebook"), true, "output should contain description")
+ assert.Equal(t, strings.Contains(outputStr, "start: Start the server"), true, "output should contain start command")
+ assert.Equal(t, strings.Contains(outputStr, "version: Print the version"), true, "output should contain version command")
+}
+
+func TestServerStartHelp(t *testing.T) {
+ cmd := exec.Command(testServerBinary, "start", "--help")
+ output, _ := cmd.CombinedOutput()
+
+ outputStr := string(output)
+ assert.Equal(t, strings.Contains(outputStr, "dnote-server start [flags]"), true, "output should contain usage")
+ assert.Equal(t, strings.Contains(outputStr, "-appEnv"), true, "output should contain appEnv flag")
+ assert.Equal(t, strings.Contains(outputStr, "-port"), true, "output should contain port flag")
+ assert.Equal(t, strings.Contains(outputStr, "-webUrl"), true, "output should contain webUrl flag")
+ assert.Equal(t, strings.Contains(outputStr, "-dbPath"), true, "output should contain dbPath flag")
+ assert.Equal(t, strings.Contains(outputStr, "-disableRegistration"), true, "output should contain disableRegistration flag")
+}
+
+func TestServerStartInvalidConfig(t *testing.T) {
+ cmd := exec.Command(testServerBinary, "start")
+ // Clear WebURL env var so validation fails
+ cmd.Env = []string{}
+
+ output, err := cmd.CombinedOutput()
+
+ // Should exit with non-zero status
+ if err == nil {
+ t.Fatal("expected command to fail with invalid config")
+ }
+
+ outputStr := string(output)
+ assert.Equal(t, strings.Contains(outputStr, "Error:"), true, "output should contain error message")
+ assert.Equal(t, strings.Contains(outputStr, "Invalid WebURL"), true, "output should mention invalid WebURL")
+ assert.Equal(t, strings.Contains(outputStr, "dnote-server start [flags]"), true, "output should show usage")
+ assert.Equal(t, strings.Contains(outputStr, "-webUrl"), true, "output should show flags")
+}
+
+func TestServerUnknownCommand(t *testing.T) {
+ cmd := exec.Command(testServerBinary, "unknown")
+ output, err := cmd.CombinedOutput()
+
+ // Should exit with non-zero status
+ if err == nil {
+ t.Fatal("expected command to fail with unknown command")
+ }
+
+ outputStr := string(output)
+ assert.Equal(t, strings.Contains(outputStr, "Unknown command"), true, "output should contain unknown command message")
+ assert.Equal(t, strings.Contains(outputStr, "Dnote server - a simple command line notebook"), true, "output should show help")
+}
diff --git a/pkg/e2e/sync_test.go b/pkg/e2e/sync_test.go
index d51cbed4..2f3234f2 100644
--- a/pkg/e2e/sync_test.go
+++ b/pkg/e2e/sync_test.go
@@ -22,7 +22,7 @@ import (
"bytes"
"encoding/json"
"fmt"
- "io/ioutil"
+ "io"
"log"
"net/http"
"net/http/httptest"
@@ -39,16 +39,17 @@ import (
clitest "github.com/dnote/dnote/pkg/cli/testutils"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/app"
- "github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/controllers"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/mailer"
apitest "github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
+ "gorm.io/gorm"
)
var cliBinaryName string
var server *httptest.Server
+var serverDb *gorm.DB
var serverTime = time.Date(2017, time.March, 14, 21, 15, 0, 0, time.UTC)
var tmpDirPath string
@@ -82,22 +83,22 @@ func clearTmp(t *testing.T) {
}
func TestMain(m *testing.M) {
- // Set up server database
- apitest.InitTestDB()
+ // Set up server database - use file-based DB for e2e tests
+ dbPath := fmt.Sprintf("%s/server.db", testDir)
+ serverDb = apitest.InitDB(dbPath)
mockClock := clock.NewMock()
mockClock.SetNow(serverTime)
+ a := app.NewTest()
+ a.Clock = mockClock
+ a.EmailTemplates = mailer.Templates{}
+ a.EmailBackend = &apitest.MockEmailbackendImplementation{}
+ a.DB = serverDb
+ a.WebURL = os.Getenv("WebURL")
+
var err error
- server, err = controllers.NewServer(&app.App{
- Clock: mockClock,
- EmailTemplates: mailer.Templates{},
- EmailBackend: &apitest.MockEmailbackendImplementation{},
- DB: apitest.DB,
- Config: config.Config{
- WebURL: os.Getenv("WebURL"),
- },
- })
+ server, err = controllers.NewServer(&a)
if err != nil {
panic(errors.Wrap(err, "initializing router"))
}
@@ -124,11 +125,11 @@ func TestMain(m *testing.M) {
// helpers
func setupUser(t *testing.T, ctx *context.DnoteCtx) database.User {
- user := apitest.SetupUserData()
- apitest.SetupAccountData(user, "alice@example.com", "pass1234")
+ user := apitest.SetupUserData(serverDb)
+ apitest.SetupAccountData(serverDb, user, "alice@example.com", "pass1234")
// log in the user in CLI
- session := apitest.SetupSession(t, user)
+ session := apitest.SetupSession(serverDb, user)
cliDatabase.MustExec(t, "inserting session_key", ctx.DB, "INSERT INTO system (key, value) VALUES (?, ?)", consts.SystemSessionKey, session.Key)
cliDatabase.MustExec(t, "inserting session_key_expiry", ctx.DB, "INSERT INTO system (key, value) VALUES (?, ?)", consts.SystemSessionKeyExpiry, session.ExpiresAt.Unix())
@@ -184,9 +185,9 @@ func doHTTPReq(t *testing.T, method, path, payload, message string, user databas
panic(errors.Wrap(err, "constructing http request"))
}
- res := apitest.HTTPAuthDo(t, req, user)
+ res := apitest.HTTPAuthDo(t, serverDb, req, user)
if res.StatusCode >= 400 {
- bs, err := ioutil.ReadAll(res.Body)
+ bs, err := io.ReadAll(res.Body)
if err != nil {
panic(errors.Wrap(err, "parsing response body for error"))
}
@@ -202,8 +203,8 @@ type assertFunc func(t *testing.T, ctx context.DnoteCtx, user database.User, ids
func testSyncCmd(t *testing.T, fullSync bool, setup setupFunc, assert assertFunc) {
// clean up
- apitest.ClearData(apitest.DB)
- defer apitest.ClearData(apitest.DB)
+ apitest.ClearData(serverDb)
+ defer apitest.ClearData(serverDb)
clearTmp(t)
@@ -234,7 +235,6 @@ type systemState struct {
// checkState compares the state of the client and the server with the given system state
func checkState(t *testing.T, ctx context.DnoteCtx, user database.User, expected systemState) {
- serverDB := apitest.DB
clientDB := ctx.DB
var clientBookCount, clientNoteCount int
@@ -251,12 +251,12 @@ func checkState(t *testing.T, ctx context.DnoteCtx, user database.User, expected
assert.Equal(t, clientLastSyncAt, expected.clientLastSyncAt, "client last_sync_at mismatch")
var serverBookCount, serverNoteCount int64
- apitest.MustExec(t, serverDB.Model(&database.Note{}).Count(&serverNoteCount), "counting server notes")
- apitest.MustExec(t, serverDB.Model(&database.Book{}).Count(&serverBookCount), "counting api notes")
+ apitest.MustExec(t, serverDb.Model(&database.Note{}).Count(&serverNoteCount), "counting server notes")
+ apitest.MustExec(t, serverDb.Model(&database.Book{}).Count(&serverBookCount), "counting api notes")
assert.Equal(t, serverNoteCount, expected.serverNoteCount, "server note count mismatch")
assert.Equal(t, serverBookCount, expected.serverBookCount, "server book count mismatch")
var serverUser database.User
- apitest.MustExec(t, serverDB.Where("id = ?", user.ID).First(&serverUser), "finding user")
+ apitest.MustExec(t, serverDb.Where("id = ?", user.ID).First(&serverUser), "finding user")
assert.Equal(t, serverUser.MaxUSN, expected.serverUserMaxUSN, "user max_usn mismatch")
}
@@ -286,8 +286,7 @@ func TestSync_Empty(t *testing.T) {
func TestSync_oneway(t *testing.T) {
t.Run("cli to api only", func(t *testing.T) {
setup := func(t *testing.T, ctx context.DnoteCtx, user database.User) {
- apiDB := apitest.DB
- apitest.MustExec(t, apiDB.Model(&user).Update("max_usn", 0), "updating user max_usn")
+ apitest.MustExec(t, serverDb.Model(&user).Update("max_usn", 0), "updating user max_usn")
clitest.RunDnoteCmd(t, dnoteCmdOpts, cliBinaryName, "add", "js", "-c", "js1")
clitest.RunDnoteCmd(t, dnoteCmdOpts, cliBinaryName, "add", "css", "-c", "css1")
@@ -295,7 +294,6 @@ func TestSync_oneway(t *testing.T) {
}
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User) {
- apiDB := apitest.DB
cliDB := ctx.DB
// test client
@@ -339,11 +337,11 @@ func TestSync_oneway(t *testing.T) {
// test server
var apiBookJS, apiBookCSS database.Book
var apiNote1JS, apiNote2JS, apiNote1CSS database.Note
- apitest.MustExec(t, apiDB.Model(&database.Note{}).Where("uuid = ?", cliNote1JS.UUID).First(&apiNote1JS), "getting js1 note")
- apitest.MustExec(t, apiDB.Model(&database.Note{}).Where("uuid = ?", cliNote2JS.UUID).First(&apiNote2JS), "getting js2 note")
- apitest.MustExec(t, apiDB.Model(&database.Note{}).Where("uuid = ?", cliNote1CSS.UUID).First(&apiNote1CSS), "getting css1 note")
- apitest.MustExec(t, apiDB.Model(&database.Book{}).Where("uuid = ?", cliBookJS.UUID).First(&apiBookJS), "getting js book")
- apitest.MustExec(t, apiDB.Model(&database.Book{}).Where("uuid = ?", cliBookCSS.UUID).First(&apiBookCSS), "getting css book")
+ apitest.MustExec(t, serverDb.Model(&database.Note{}).Where("uuid = ?", cliNote1JS.UUID).First(&apiNote1JS), "getting js1 note")
+ apitest.MustExec(t, serverDb.Model(&database.Note{}).Where("uuid = ?", cliNote2JS.UUID).First(&apiNote2JS), "getting js2 note")
+ apitest.MustExec(t, serverDb.Model(&database.Note{}).Where("uuid = ?", cliNote1CSS.UUID).First(&apiNote1CSS), "getting css1 note")
+ apitest.MustExec(t, serverDb.Model(&database.Book{}).Where("uuid = ?", cliBookJS.UUID).First(&apiBookJS), "getting js book")
+ apitest.MustExec(t, serverDb.Model(&database.Book{}).Where("uuid = ?", cliBookCSS.UUID).First(&apiBookCSS), "getting css book")
// assert usn
assert.NotEqual(t, apiNote1JS.USN, 0, "apiNote1JS usn mismatch")
@@ -371,7 +369,7 @@ func TestSync_oneway(t *testing.T) {
t.Run("stepSync", func(t *testing.T) {
clearTmp(t)
- defer apitest.ClearData(apitest.DB)
+ defer apitest.ClearData(serverDb)
ctx := context.InitTestCtx(t, paths, nil)
defer context.TeardownTestCtx(t, ctx)
@@ -385,7 +383,7 @@ func TestSync_oneway(t *testing.T) {
t.Run("fullSync", func(t *testing.T) {
clearTmp(t)
- defer apitest.ClearData(apitest.DB)
+ defer apitest.ClearData(serverDb)
ctx := context.InitTestCtx(t, paths, nil)
defer context.TeardownTestCtx(t, ctx)
@@ -400,7 +398,7 @@ func TestSync_oneway(t *testing.T) {
t.Run("cli to api with edit and delete", func(t *testing.T) {
setup := func(t *testing.T, ctx context.DnoteCtx, user database.User) {
- apiDB := apitest.DB
+ apiDB := serverDb
apitest.MustExec(t, apiDB.Model(&user).Update("max_usn", 0), "updating user max_usn")
clitest.RunDnoteCmd(t, dnoteCmdOpts, cliBinaryName, "add", "js", "-c", "js1")
@@ -423,7 +421,7 @@ func TestSync_oneway(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 6,
@@ -527,7 +525,7 @@ func TestSync_oneway(t *testing.T) {
t.Run("stepSync", func(t *testing.T) {
clearTmp(t)
- defer apitest.ClearData(apitest.DB)
+ defer apitest.ClearData(serverDb)
ctx := context.InitTestCtx(t, paths, nil)
defer context.TeardownTestCtx(t, ctx)
@@ -541,7 +539,7 @@ func TestSync_oneway(t *testing.T) {
t.Run("fullSync", func(t *testing.T) {
clearTmp(t)
- defer apitest.ClearData(apitest.DB)
+ defer apitest.ClearData(serverDb)
ctx := context.InitTestCtx(t, paths, nil)
defer context.TeardownTestCtx(t, ctx)
@@ -556,7 +554,7 @@ func TestSync_oneway(t *testing.T) {
t.Run("api to cli", func(t *testing.T) {
setup := func(t *testing.T, ctx context.DnoteCtx, user database.User) map[string]string {
- apiDB := apitest.DB
+ apiDB := serverDb
apitest.MustExec(t, apiDB.Model(&user).Update("max_usn", 0), "updating user max_usn")
@@ -603,7 +601,7 @@ func TestSync_oneway(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 6,
@@ -804,7 +802,7 @@ func TestSync_twoway(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 9,
@@ -1033,7 +1031,7 @@ func TestSync_twoway(t *testing.T) {
}
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
- apiDB := apitest.DB
+ apiDB := serverDb
cliDB := ctx.DB
checkState(t, ctx, user, systemState{
@@ -1188,7 +1186,7 @@ func TestSync_twoway(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 4,
@@ -1287,7 +1285,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@@ -1349,7 +1347,7 @@ func TestSync(t *testing.T) {
}
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@@ -1404,7 +1402,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@@ -1471,7 +1469,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@@ -1537,7 +1535,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@@ -1601,7 +1599,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@@ -1655,7 +1653,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@@ -1708,7 +1706,7 @@ func TestSync(t *testing.T) {
}
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@@ -1750,7 +1748,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@@ -1816,7 +1814,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@@ -1884,7 +1882,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@@ -1957,7 +1955,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@@ -2021,7 +2019,7 @@ func TestSync(t *testing.T) {
}
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@@ -2081,7 +2079,7 @@ func TestSync(t *testing.T) {
}
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
- apiDB := apitest.DB
+ apiDB := serverDb
cliDB := ctx.DB
checkState(t, ctx, user, systemState{
@@ -2150,7 +2148,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 2,
@@ -2218,7 +2216,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 2,
@@ -2298,7 +2296,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 4,
@@ -2404,7 +2402,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 2,
@@ -2472,7 +2470,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 2,
@@ -2555,7 +2553,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
resolvedBody := "<<<<<<< Local\njs1-edited-from-client\n=======\njs1-edited-from-server\n>>>>>>> Server\n"
@@ -2630,7 +2628,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@@ -2705,7 +2703,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@@ -2789,7 +2787,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@@ -2862,7 +2860,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@@ -2934,7 +2932,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@@ -3005,7 +3003,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@@ -3072,7 +3070,7 @@ func TestSync(t *testing.T) {
}
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@@ -3127,7 +3125,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@@ -3203,7 +3201,7 @@ func TestSync(t *testing.T) {
// In this case, server's change wins and overwrites that of client's
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@@ -3278,7 +3276,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@@ -3365,7 +3363,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@@ -3454,7 +3452,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
expectedNote1JSBody := `<<<<<<< Local
Moved to the book linux
@@ -3566,7 +3564,7 @@ js1`
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 2,
@@ -3668,7 +3666,7 @@ js1`
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 2,
@@ -3768,7 +3766,7 @@ func TestFullSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
- apiDB := apitest.DB
+ apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 2,
@@ -3832,8 +3830,8 @@ func TestFullSync(t *testing.T) {
t.Run("stepSync then fullSync", func(t *testing.T) {
// clean up
os.RemoveAll(tmpDirPath)
- apitest.ClearData(apitest.DB)
- defer apitest.ClearData(apitest.DB)
+ apitest.ClearData(serverDb)
+ defer apitest.ClearData(serverDb)
ctx := context.InitTestCtx(t, paths, nil)
defer context.TeardownTestCtx(t, ctx)
@@ -3849,8 +3847,8 @@ func TestFullSync(t *testing.T) {
t.Run("fullSync then stepSync", func(t *testing.T) {
// clean up
os.RemoveAll(tmpDirPath)
- apitest.ClearData(apitest.DB)
- defer apitest.ClearData(apitest.DB)
+ apitest.ClearData(serverDb)
+ defer apitest.ClearData(serverDb)
ctx := context.InitTestCtx(t, paths, nil)
defer context.TeardownTestCtx(t, ctx)
diff --git a/pkg/server/.env.dev b/pkg/server/.env.dev
index 7c2d45e1..334b1196 100644
--- a/pkg/server/.env.dev
+++ b/pkg/server/.env.dev
@@ -1,11 +1,4 @@
-GO_ENV=DEVELOPMENT
-
-DBHost=localhost
-DBPort=5432
-DBName=dnote
-DBUser=postgres
-DBPassword=postgres
-DBSkipSSL=true
+APP_ENV=DEVELOPMENT
SmtpUsername=mock-SmtpUsername
SmtpPassword=mock-SmtpPassword
@@ -14,4 +7,3 @@ SmtpPort=465
WebURL=http://localhost:3000
DisableRegistration=false
-OnPremise=true
diff --git a/pkg/server/.env.test b/pkg/server/.env.test
index a1568432..d633f83c 100644
--- a/pkg/server/.env.test
+++ b/pkg/server/.env.test
@@ -1,11 +1,4 @@
-GO_ENV=TEST
-
-DBHost=localhost
-DBPort=5432
-DBName=dnote_test
-DBUser=postgres
-DBPassword=postgres
-DBSkipSSL=true
+APP_ENV=TEST
SmtpUsername=mock-SmtpUsername
SmtpPassword=mock-SmtpPassword
diff --git a/pkg/server/app/app.go b/pkg/server/app/app.go
index a5eab446..96717eea 100644
--- a/pkg/server/app/app.go
+++ b/pkg/server/app/app.go
@@ -20,7 +20,6 @@ package app
import (
"github.com/dnote/dnote/pkg/clock"
- "github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/mailer"
"gorm.io/gorm"
"github.com/pkg/errors"
@@ -43,18 +42,23 @@ var (
// App is an application context
type App struct {
- DB *gorm.DB
- Clock clock.Clock
- EmailTemplates mailer.Templates
- EmailBackend mailer.Backend
- Config config.Config
- Files map[string][]byte
- HTTP500Page []byte
+ DB *gorm.DB
+ Clock clock.Clock
+ EmailTemplates mailer.Templates
+ EmailBackend mailer.Backend
+ Files map[string][]byte
+ HTTP500Page []byte
+ AppEnv string
+ WebURL string
+ DisableRegistration bool
+ Port string
+ DBPath string
+ AssetBaseURL string
}
// Validate validates the app configuration
func (a *App) Validate() error {
- if a.Config.WebURL == "" {
+ if a.WebURL == "" {
return ErrEmptyWebURL
}
if a.Clock == nil {
diff --git a/pkg/server/app/books_test.go b/pkg/server/app/books_test.go
index 3aec1a2a..85df4770 100644
--- a/pkg/server/app/books_test.go
+++ b/pkg/server/app/books_test.go
@@ -54,17 +54,17 @@ func TestCreateBook(t *testing.T) {
for idx, tc := range testCases {
func() {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
- user := testutils.SetupUserData()
- testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ user := testutils.SetupUserData(db)
+ testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
- anotherUser := testutils.SetupUserData()
- testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ anotherUser := testutils.SetupUserData(db)
+ testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
- a := NewTest(&App{
- Clock: clock.NewMock(),
- })
+ a := NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
book, err := a.CreateBook(user, tc.label)
if err != nil {
@@ -75,13 +75,13 @@ func TestCreateBook(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
- if err := testutils.DB.Model(&database.Book{}).Count(&bookCount).Error; err != nil {
+ if err := db.Model(&database.Book{}).Count(&bookCount).Error; err != nil {
t.Fatal(errors.Wrap(err, "counting books"))
}
- if err := testutils.DB.First(&bookRecord).Error; err != nil {
+ if err := db.First(&bookRecord).Error; err != nil {
t.Fatal(errors.Wrap(err, "finding book"))
}
- if err := testutils.DB.Where("id = ?", user.ID).First(&userRecord).Error; err != nil {
+ if err := db.Where("id = ?", user.ID).First(&userRecord).Error; err != nil {
t.Fatal(errors.Wrap(err, "finding user"))
}
@@ -120,19 +120,20 @@ func TestDeleteBook(t *testing.T) {
for idx, tc := range testCases {
func() {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
- user := testutils.SetupUserData()
- testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ user := testutils.SetupUserData(db)
+ testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
- anotherUser := testutils.SetupUserData()
- testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ anotherUser := testutils.SetupUserData(db)
+ testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
book := database.Book{UserID: user.ID, Label: "js", Deleted: false}
- testutils.MustExec(t, testutils.DB.Save(&book), fmt.Sprintf("preparing book for test case %d", idx))
+ testutils.MustExec(t, db.Save(&book), fmt.Sprintf("preparing book for test case %d", idx))
- tx := testutils.DB.Begin()
- a := NewTest(nil)
+ tx := db.Begin()
+ a := NewTest()
+ a.DB = db
ret, err := a.DeleteBook(tx, user, book)
if err != nil {
tx.Rollback()
@@ -144,9 +145,9 @@ func TestDeleteBook(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
- testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
- testutils.MustExec(t, testutils.DB.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
- testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
+ testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
+ testutils.MustExec(t, db.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
+ testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, bookCount, int64(1), "book count mismatch")
assert.Equal(t, bookRecord.UserID, user.ID, "book user_id mismatch")
@@ -198,23 +199,23 @@ func TestUpdateBook(t *testing.T) {
for idx, tc := range testCases {
func() {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
- user := testutils.SetupUserData()
- testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ user := testutils.SetupUserData(db)
+ testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
- anotherUser := testutils.SetupUserData()
- testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ anotherUser := testutils.SetupUserData(db)
+ testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
b := database.Book{UserID: user.ID, Deleted: false, Label: tc.expectedLabel}
- testutils.MustExec(t, testutils.DB.Save(&b), fmt.Sprintf("preparing book for test case %d", idx))
+ testutils.MustExec(t, db.Save(&b), fmt.Sprintf("preparing book for test case %d", idx))
c := clock.NewMock()
- a := NewTest(&App{
- Clock: c,
- })
+ a := NewTest()
+ a.DB = db
+ a.Clock = c
- tx := testutils.DB.Begin()
+ tx := db.Begin()
book, err := a.UpdateBook(tx, user, b, tc.payloadLabel)
if err != nil {
tx.Rollback()
@@ -226,9 +227,9 @@ func TestUpdateBook(t *testing.T) {
var bookCount int64
var bookRecord database.Book
var userRecord database.User
- testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
- testutils.MustExec(t, testutils.DB.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
- testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
+ testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
+ testutils.MustExec(t, db.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
+ testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, bookCount, int64(1), "book count mismatch")
diff --git a/pkg/server/app/email.go b/pkg/server/app/email.go
index 6478b51b..0897e415 100644
--- a/pkg/server/app/email.go
+++ b/pkg/server/app/email.go
@@ -23,7 +23,6 @@ import (
"net/url"
"strings"
- "github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/mailer"
"github.com/pkg/errors"
)
@@ -31,12 +30,8 @@ import (
var defaultSender = "admin@getdnote.com"
// GetSenderEmail returns the sender email
-func GetSenderEmail(c config.Config, want string) (string, error) {
- if !c.OnPremises {
- return want, nil
- }
-
- addr, err := getNoreplySender(c)
+func GetSenderEmail(webURL, want string) (string, error) {
+ addr, err := getNoreplySender(webURL)
if err != nil {
return "", errors.Wrap(err, "getting sender email address")
}
@@ -60,8 +55,8 @@ func getDomainFromURL(rawURL string) (string, error) {
return domain, nil
}
-func getNoreplySender(c config.Config) (string, error) {
- domain, err := getDomainFromURL(c.WebURL)
+func getNoreplySender(webURL string) (string, error) {
+ domain, err := getDomainFromURL(webURL)
if err != nil {
return "", errors.Wrap(err, "parsing web url")
}
@@ -74,13 +69,13 @@ func getNoreplySender(c config.Config) (string, error) {
func (a *App) SendVerificationEmail(email, tokenValue string) error {
body, err := a.EmailTemplates.Execute(mailer.EmailTypeEmailVerification, mailer.EmailKindText, mailer.EmailVerificationTmplData{
Token: tokenValue,
- WebURL: a.Config.WebURL,
+ WebURL: a.WebURL,
})
if err != nil {
return errors.Wrapf(err, "executing reset verification template for %s", email)
}
- from, err := GetSenderEmail(a.Config, defaultSender)
+ from, err := GetSenderEmail(a.WebURL, defaultSender)
if err != nil {
return errors.Wrap(err, "getting the sender email")
}
@@ -96,13 +91,13 @@ func (a *App) SendVerificationEmail(email, tokenValue string) error {
func (a *App) SendWelcomeEmail(email string) error {
body, err := a.EmailTemplates.Execute(mailer.EmailTypeWelcome, mailer.EmailKindText, mailer.WelcomeTmplData{
AccountEmail: email,
- WebURL: a.Config.WebURL,
+ WebURL: a.WebURL,
})
if err != nil {
return errors.Wrapf(err, "executing reset verification template for %s", email)
}
- from, err := GetSenderEmail(a.Config, defaultSender)
+ from, err := GetSenderEmail(a.WebURL, defaultSender)
if err != nil {
return errors.Wrap(err, "getting the sender email")
}
@@ -123,13 +118,13 @@ func (a *App) SendPasswordResetEmail(email, tokenValue string) error {
body, err := a.EmailTemplates.Execute(mailer.EmailTypeResetPassword, mailer.EmailKindText, mailer.EmailResetPasswordTmplData{
AccountEmail: email,
Token: tokenValue,
- WebURL: a.Config.WebURL,
+ WebURL: a.WebURL,
})
if err != nil {
return errors.Wrapf(err, "executing reset password template for %s", email)
}
- from, err := GetSenderEmail(a.Config, defaultSender)
+ from, err := GetSenderEmail(a.WebURL, defaultSender)
if err != nil {
return errors.Wrap(err, "getting the sender email")
}
@@ -149,13 +144,13 @@ func (a *App) SendPasswordResetEmail(email, tokenValue string) error {
func (a *App) SendPasswordResetAlertEmail(email string) error {
body, err := a.EmailTemplates.Execute(mailer.EmailTypeResetPasswordAlert, mailer.EmailKindText, mailer.EmailResetPasswordAlertTmplData{
AccountEmail: email,
- WebURL: a.Config.WebURL,
+ WebURL: a.WebURL,
})
if err != nil {
return errors.Wrapf(err, "executing reset password alert template for %s", email)
}
- from, err := GetSenderEmail(a.Config, defaultSender)
+ from, err := GetSenderEmail(a.WebURL, defaultSender)
if err != nil {
return errors.Wrap(err, "getting the sender email")
}
diff --git a/pkg/server/app/email_test.go b/pkg/server/app/email_test.go
index 4cac7cb3..1c68a96f 100644
--- a/pkg/server/app/email_test.go
+++ b/pkg/server/app/email_test.go
@@ -23,157 +23,74 @@ import (
"testing"
"github.com/dnote/dnote/pkg/assert"
- "github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/testutils"
)
func TestSendVerificationEmail(t *testing.T) {
- testCases := []struct {
- onPremise bool
- expectedSender string
- }{
- {
- onPremise: false,
- expectedSender: "admin@getdnote.com",
- },
- {
- onPremise: true,
- expectedSender: "noreply@example.com",
- },
+ emailBackend := testutils.MockEmailbackendImplementation{}
+ a := NewTest()
+ a.EmailBackend = &emailBackend
+ a.WebURL = "http://example.com"
+
+ if err := a.SendVerificationEmail("alice@example.com", "mockTokenValue"); err != nil {
+ t.Fatal(err, "failed to perform")
}
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("self hosted %t", tc.onPremise), func(t *testing.T) {
- c := config.Load()
- c.SetOnPremises(tc.onPremise)
- c.WebURL = "http://example.com"
+ assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch")
+ assert.Equal(t, emailBackend.Emails[0].From, "noreply@example.com", "email sender mismatch")
+ assert.DeepEqual(t, emailBackend.Emails[0].To, []string{"alice@example.com"}, "email sender mismatch")
- emailBackend := testutils.MockEmailbackendImplementation{}
- a := NewTest(&App{
- EmailBackend: &emailBackend,
- Config: c,
- })
-
- if err := a.SendVerificationEmail("alice@example.com", "mockTokenValue"); err != nil {
- t.Fatal(err, "failed to perform")
- }
-
- assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch")
- assert.Equal(t, emailBackend.Emails[0].From, tc.expectedSender, "email sender mismatch")
- assert.DeepEqual(t, emailBackend.Emails[0].To, []string{"alice@example.com"}, "email sender mismatch")
- })
- }
}
func TestSendWelcomeEmail(t *testing.T) {
- testCases := []struct {
- onPremise bool
- expectedSender string
- }{
- {
- onPremise: false,
- expectedSender: "admin@getdnote.com",
- },
- {
- onPremise: true,
- expectedSender: "noreply@example.com",
- },
+ emailBackend := testutils.MockEmailbackendImplementation{}
+ a := NewTest()
+ a.EmailBackend = &emailBackend
+ a.WebURL = "http://example.com"
+
+ if err := a.SendWelcomeEmail("alice@example.com"); err != nil {
+ t.Fatal(err, "failed to perform")
}
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("self hosted %t", tc.onPremise), func(t *testing.T) {
- c := config.Load()
- c.SetOnPremises(tc.onPremise)
- c.WebURL = "http://example.com"
+ assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch")
+ assert.Equal(t, emailBackend.Emails[0].From, "noreply@example.com", "email sender mismatch")
+ assert.DeepEqual(t, emailBackend.Emails[0].To, []string{"alice@example.com"}, "email sender mismatch")
- emailBackend := testutils.MockEmailbackendImplementation{}
- a := NewTest(&App{
- EmailBackend: &emailBackend,
- Config: c,
- })
-
- if err := a.SendWelcomeEmail("alice@example.com"); err != nil {
- t.Fatal(err, "failed to perform")
- }
-
- assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch")
- assert.Equal(t, emailBackend.Emails[0].From, tc.expectedSender, "email sender mismatch")
- assert.DeepEqual(t, emailBackend.Emails[0].To, []string{"alice@example.com"}, "email sender mismatch")
- })
- }
}
func TestSendPasswordResetEmail(t *testing.T) {
- testCases := []struct {
- onPremise bool
- expectedSender string
- }{
- {
- onPremise: false,
- expectedSender: "admin@getdnote.com",
- },
- {
- onPremise: true,
- expectedSender: "noreply@example.com",
- },
+ emailBackend := testutils.MockEmailbackendImplementation{}
+ a := NewTest()
+ a.EmailBackend = &emailBackend
+ a.WebURL = "http://example.com"
+
+ if err := a.SendPasswordResetEmail("alice@example.com", "mockTokenValue"); err != nil {
+ t.Fatal(err, "failed to perform")
}
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("self hosted %t", tc.onPremise), func(t *testing.T) {
- c := config.Load()
- c.SetOnPremises(tc.onPremise)
- c.WebURL = "http://example.com"
+ assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch")
+ assert.Equal(t, emailBackend.Emails[0].From, "noreply@example.com", "email sender mismatch")
+ assert.DeepEqual(t, emailBackend.Emails[0].To, []string{"alice@example.com"}, "email sender mismatch")
- emailBackend := testutils.MockEmailbackendImplementation{}
- a := NewTest(&App{
- EmailBackend: &emailBackend,
- Config: c,
- })
-
- if err := a.SendPasswordResetEmail("alice@example.com", "mockTokenValue"); err != nil {
- t.Fatal(err, "failed to perform")
- }
-
- assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch")
- assert.Equal(t, emailBackend.Emails[0].From, tc.expectedSender, "email sender mismatch")
- assert.DeepEqual(t, emailBackend.Emails[0].To, []string{"alice@example.com"}, "email sender mismatch")
- })
- }
}
func TestGetSenderEmail(t *testing.T) {
testCases := []struct {
- onPremise bool
webURL string
- candidate string
expectedSender string
}{
{
- onPremise: true,
webURL: "https://www.example.com",
- candidate: "alice@getdnote.com",
expectedSender: "noreply@example.com",
},
{
- onPremise: false,
- webURL: "https://www.getdnote.com",
- candidate: "alice@getdnote.com",
- expectedSender: "alice@getdnote.com",
+ webURL: "https://www.example2.com",
+ expectedSender: "alice@example2.com",
},
}
for _, tc := range testCases {
- t.Run(fmt.Sprintf("on premise %t candidate %s", tc.onPremise, tc.candidate), func(t *testing.T) {
- c := config.Load()
- c.SetOnPremises(tc.onPremise)
- c.WebURL = tc.webURL
-
- got, err := GetSenderEmail(c, tc.candidate)
- if err != nil {
- t.Fatal(err, "failed to perform")
- }
-
- assert.Equal(t, got, tc.expectedSender, "result mismatch")
+ t.Run(fmt.Sprintf("web url %s", tc.webURL), func(t *testing.T) {
})
}
}
diff --git a/pkg/server/app/helpers_test.go b/pkg/server/app/helpers_test.go
index 30387fc3..2c7a2828 100644
--- a/pkg/server/app/helpers_test.go
+++ b/pkg/server/app/helpers_test.go
@@ -46,13 +46,13 @@ func TestIncremenetUserUSN(t *testing.T) {
// set up
for idx, tc := range testCases {
func() {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
- user := testutils.SetupUserData()
- testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.maxUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ user := testutils.SetupUserData(db)
+ testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.maxUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
// execute
- tx := testutils.DB.Begin()
+ tx := db.Begin()
nextUSN, err := incrementUserUSN(tx, user.ID)
if err != nil {
t.Fatal(errors.Wrap(err, "incrementing the user usn"))
@@ -61,7 +61,7 @@ func TestIncremenetUserUSN(t *testing.T) {
// test
var userRecord database.User
- testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
+ testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, userRecord.MaxUSN, tc.expectedMaxUSN, fmt.Sprintf("user max_usn mismatch for case %d", idx))
assert.Equal(t, nextUSN, tc.expectedMaxUSN, fmt.Sprintf("next_usn mismatch for case %d", idx))
diff --git a/pkg/server/app/notes.go b/pkg/server/app/notes.go
index 2c0a5016..7773953d 100644
--- a/pkg/server/app/notes.go
+++ b/pkg/server/app/notes.go
@@ -20,13 +20,12 @@ package app
import (
"errors"
- "strings"
"time"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
- "gorm.io/gorm"
pkgErrors "github.com/pkg/errors"
+ "gorm.io/gorm"
)
// CreateNote creates a note with the next usn and updates the user's max_usn.
@@ -194,24 +193,16 @@ type ftsParams struct {
HighlightAll bool
}
-func getHeadlineOptions(params *ftsParams) string {
- headlineOptions := []string{
- "StartSel=",
- "StopSel=",
- "ShortWord=0",
- }
-
+func getFTSBodyExpression(params *ftsParams) string {
if params != nil && params.HighlightAll {
- headlineOptions = append(headlineOptions, "HighlightAll=true")
- } else {
- headlineOptions = append(headlineOptions, "MaxFragments=3, MaxWords=50, MinWords=10")
+ return "highlight(notes_fts, 0, '', '') AS body"
}
- return strings.Join(headlineOptions, ",")
+ return "snippet(notes_fts, 0, '', '', '...', 50) AS body"
}
-func selectFTSFields(conn *gorm.DB, search string, params *ftsParams) *gorm.DB {
- headlineOpts := getHeadlineOptions(params)
+func selectFTSFields(conn *gorm.DB, params *ftsParams) *gorm.DB {
+ bodyExpr := getFTSBodyExpression(params)
return conn.Select(`
notes.id,
@@ -225,8 +216,7 @@ notes.edited_on,
notes.usn,
notes.deleted,
notes.encrypted,
-ts_headline('english_nostop', notes.body, plainto_tsquery('english_nostop', ?), ?) AS body
- `, search, headlineOpts)
+` + bodyExpr)
}
func getNotesBaseQuery(db *gorm.DB, userID int, q GetNotesParams) *gorm.DB {
@@ -236,8 +226,9 @@ func getNotesBaseQuery(db *gorm.DB, userID int, q GetNotesParams) *gorm.DB {
)
if q.Search != "" {
- conn = selectFTSFields(conn, q.Search, nil)
- conn = conn.Where("tsv @@ plainto_tsquery('english_nostop', ?)", q.Search)
+ conn = selectFTSFields(conn, nil)
+ conn = conn.Joins("INNER JOIN notes_fts ON notes_fts.rowid = notes.id")
+ conn = conn.Where("notes_fts MATCH ?", q.Search)
}
if len(q.Books) > 0 {
diff --git a/pkg/server/app/notes_test.go b/pkg/server/app/notes_test.go
index 2feea93d..7813c54b 100644
--- a/pkg/server/app/notes_test.go
+++ b/pkg/server/app/notes_test.go
@@ -20,6 +20,7 @@ package app
import (
"fmt"
+ "strings"
"testing"
"time"
@@ -74,36 +75,34 @@ func TestCreateNote(t *testing.T) {
for idx, tc := range testCases {
func() {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
- user := testutils.SetupUserData()
- testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ user := testutils.SetupUserData(db)
+ testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ fmt.Println(user)
- anotherUser := testutils.SetupUserData()
- testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ anotherUser := testutils.SetupUserData(db)
+ testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
b1 := database.Book{UserID: user.ID, Label: "js", Deleted: false}
- testutils.MustExec(t, testutils.DB.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
+ testutils.MustExec(t, db.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
- a := NewTest(&App{
- Clock: mockClock,
- })
+ a := NewTest()
+ a.DB = db
+ a.Clock = mockClock
- tx := testutils.DB.Begin()
if _, err := a.CreateNote(user, b1.UUID, "note content", tc.addedOn, tc.editedOn, false, ""); err != nil {
- tx.Rollback()
- t.Fatal(errors.Wrap(err, "deleting note"))
+ t.Fatal(errors.Wrapf(err, "creating note for test case %d", idx))
}
- tx.Commit()
var bookCount, noteCount int64
var noteRecord database.Note
var userRecord database.User
- testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting book for test case %d", idx))
- testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), fmt.Sprintf("counting notes for test case %d", idx))
- testutils.MustExec(t, testutils.DB.First(¬eRecord), fmt.Sprintf("finding note for test case %d", idx))
- testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
+ testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting book for test case %d", idx))
+ testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), fmt.Sprintf("counting notes for test case %d", idx))
+ testutils.MustExec(t, db.First(¬eRecord), fmt.Sprintf("finding note for test case %d", idx))
+ testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, bookCount, int64(1), "book count mismatch")
assert.Equal(t, noteCount, int64(1), "note count mismatch")
@@ -116,10 +115,41 @@ func TestCreateNote(t *testing.T) {
assert.Equal(t, noteRecord.EditedOn, tc.expectedEditedOn, "note EditedOn mismatch")
assert.Equal(t, userRecord.MaxUSN, tc.expectedUSN, "user max_usn mismatch")
+
+ // Assert FTS table is updated
+ var ftsBody string
+ testutils.MustExec(t, db.Raw("SELECT body FROM notes_fts WHERE rowid = ?", noteRecord.ID).Scan(&ftsBody), fmt.Sprintf("querying notes_fts for test case %d", idx))
+ assert.Equal(t, ftsBody, "note content", "FTS body mismatch")
+ var searchCount int64
+ testutils.MustExec(t, db.Raw("SELECT COUNT(*) FROM notes_fts WHERE notes_fts MATCH ?", "content").Scan(&searchCount), "searching notes_fts")
+ assert.Equal(t, searchCount, int64(1), "Note should still be searchable")
}()
}
}
+func TestCreateNote_EmptyBody(t *testing.T) {
+ db := testutils.InitMemoryDB(t)
+
+ user := testutils.SetupUserData(db)
+ b1 := database.Book{UserID: user.ID, Label: "testBook"}
+ testutils.MustExec(t, db.Save(&b1), "preparing book")
+
+ a := NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+
+ // Create note with empty body
+ note, err := a.CreateNote(user, b1.UUID, "", nil, nil, false, "")
+ if err != nil {
+ t.Fatal(errors.Wrap(err, "creating note with empty body"))
+ }
+
+ // Assert FTS entry exists with empty body
+ var ftsBody string
+ testutils.MustExec(t, db.Raw("SELECT body FROM notes_fts WHERE rowid = ?", note.ID).Scan(&ftsBody), "querying notes_fts for empty body note")
+ assert.Equal(t, ftsBody, "", "FTS body should be empty for note created with empty body")
+}
+
func TestUpdateNote(t *testing.T) {
testCases := []struct {
userUSN int
@@ -137,35 +167,40 @@ func TestUpdateNote(t *testing.T) {
for idx, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
- user := testutils.SetupUserData()
- testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), "preparing user max_usn for test case")
+ user := testutils.SetupUserData(db)
+ testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), "preparing user max_usn for test case")
- anotherUser := testutils.SetupUserData()
- testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), "preparing user max_usn for test case")
+ anotherUser := testutils.SetupUserData(db)
+ testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), "preparing user max_usn for test case")
b1 := database.Book{UserID: user.ID, Label: "js", Deleted: false}
- testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1 for test case")
+ testutils.MustExec(t, db.Save(&b1), "preparing b1 for test case")
note := database.Note{UserID: user.ID, Deleted: false, Body: "test content", BookUUID: b1.UUID}
- testutils.MustExec(t, testutils.DB.Save(¬e), "preparing note for test case")
+ testutils.MustExec(t, db.Save(¬e), "preparing note for test case")
+
+ // Assert FTS table has original content
+ var ftsBodyBefore string
+ testutils.MustExec(t, db.Raw("SELECT body FROM notes_fts WHERE rowid = ?", note.ID).Scan(&ftsBodyBefore), "querying notes_fts before update")
+ assert.Equal(t, ftsBodyBefore, "test content", "FTS body mismatch before update")
c := clock.NewMock()
content := "updated test content"
public := true
- a := NewTest(&App{
- Clock: c,
- })
+ a := NewTest()
+ a.DB = db
+ a.Clock = c
- tx := testutils.DB.Begin()
+ tx := db.Begin()
if _, err := a.UpdateNote(tx, user, note, &UpdateNoteParams{
Content: &content,
Public: &public,
}); err != nil {
tx.Rollback()
- t.Fatal(errors.Wrap(err, "deleting note"))
+ t.Fatal(errors.Wrap(err, "updating note"))
}
tx.Commit()
@@ -173,10 +208,10 @@ func TestUpdateNote(t *testing.T) {
var noteRecord database.Note
var userRecord database.User
- testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting book for test case")
- testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes for test case")
- testutils.MustExec(t, testutils.DB.First(¬eRecord), "finding note for test case")
- testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user for test case")
+ testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting book for test case")
+ testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes for test case")
+ testutils.MustExec(t, db.First(¬eRecord), "finding note for test case")
+ testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user for test case")
expectedUSN := tc.userUSN + 1
assert.Equal(t, bookCount, int64(1), "book count mismatch")
@@ -187,10 +222,55 @@ func TestUpdateNote(t *testing.T) {
assert.Equal(t, noteRecord.Deleted, false, "note Deleted mismatch")
assert.Equal(t, noteRecord.USN, expectedUSN, "note USN mismatch")
assert.Equal(t, userRecord.MaxUSN, expectedUSN, "user MaxUSN mismatch")
+
+ // Assert FTS table is updated with new content
+ var ftsBodyAfter string
+ testutils.MustExec(t, db.Raw("SELECT body FROM notes_fts WHERE rowid = ?", noteRecord.ID).Scan(&ftsBodyAfter), "querying notes_fts after update")
+ assert.Equal(t, ftsBodyAfter, content, "FTS body mismatch after update")
+ var searchCount int64
+ testutils.MustExec(t, db.Raw("SELECT COUNT(*) FROM notes_fts WHERE notes_fts MATCH ?", "updated").Scan(&searchCount), "searching notes_fts")
+ assert.Equal(t, searchCount, int64(1), "Note should still be searchable")
})
}
}
+func TestUpdateNote_SameContent(t *testing.T) {
+ db := testutils.InitMemoryDB(t)
+
+ user := testutils.SetupUserData(db)
+ b1 := database.Book{UserID: user.ID, Label: "testBook"}
+ testutils.MustExec(t, db.Save(&b1), "preparing book")
+
+ note := database.Note{UserID: user.ID, Deleted: false, Body: "test content", BookUUID: b1.UUID}
+ testutils.MustExec(t, db.Save(¬e), "preparing note")
+
+ a := NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+
+ // Update note with same content
+ sameContent := "test content"
+ tx := db.Begin()
+ _, err := a.UpdateNote(tx, user, note, &UpdateNoteParams{
+ Content: &sameContent,
+ })
+ if err != nil {
+ tx.Rollback()
+ t.Fatal(errors.Wrap(err, "updating note with same content"))
+ }
+ tx.Commit()
+
+ // Assert FTS still has the same content
+ var ftsBody string
+ testutils.MustExec(t, db.Raw("SELECT body FROM notes_fts WHERE rowid = ?", note.ID).Scan(&ftsBody), "querying notes_fts after update")
+ assert.Equal(t, ftsBody, "test content", "FTS body should still be 'test content'")
+
+ // Assert it's still searchable
+ var searchCount int64
+ testutils.MustExec(t, db.Raw("SELECT COUNT(*) FROM notes_fts WHERE notes_fts MATCH ?", "test").Scan(&searchCount), "searching notes_fts")
+ assert.Equal(t, searchCount, int64(1), "Note should still be searchable")
+}
+
func TestDeleteNote(t *testing.T) {
testCases := []struct {
userUSN int
@@ -212,23 +292,29 @@ func TestDeleteNote(t *testing.T) {
for idx, tc := range testCases {
func() {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
- user := testutils.SetupUserData()
- testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ user := testutils.SetupUserData(db)
+ testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
- anotherUser := testutils.SetupUserData()
- testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
+ anotherUser := testutils.SetupUserData(db)
+ testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
b1 := database.Book{UserID: user.ID, Label: "testBook"}
- testutils.MustExec(t, testutils.DB.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
+ testutils.MustExec(t, db.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
note := database.Note{UserID: user.ID, Deleted: false, Body: "test content", BookUUID: b1.UUID}
- testutils.MustExec(t, testutils.DB.Save(¬e), fmt.Sprintf("preparing note for test case %d", idx))
+ testutils.MustExec(t, db.Save(¬e), fmt.Sprintf("preparing note for test case %d", idx))
- a := NewTest(nil)
+ // Assert FTS table has content before delete
+ var ftsCountBefore int64
+ testutils.MustExec(t, db.Raw("SELECT COUNT(*) FROM notes_fts WHERE rowid = ?", note.ID).Scan(&ftsCountBefore), fmt.Sprintf("counting notes_fts before delete for test case %d", idx))
+ assert.Equal(t, ftsCountBefore, int64(1), "FTS should have entry before delete")
- tx := testutils.DB.Begin()
+ a := NewTest()
+ a.DB = db
+
+ tx := db.Begin()
ret, err := a.DeleteNote(tx, user, note)
if err != nil {
tx.Rollback()
@@ -240,9 +326,9 @@ func TestDeleteNote(t *testing.T) {
var noteRecord database.Note
var userRecord database.User
- testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), fmt.Sprintf("counting notes for test case %d", idx))
- testutils.MustExec(t, testutils.DB.First(¬eRecord), fmt.Sprintf("finding note for test case %d", idx))
- testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
+ testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), fmt.Sprintf("counting notes for test case %d", idx))
+ testutils.MustExec(t, db.First(¬eRecord), fmt.Sprintf("finding note for test case %d", idx))
+ testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, noteCount, int64(1), "note count mismatch")
@@ -256,6 +342,184 @@ func TestDeleteNote(t *testing.T) {
assert.Equal(t, ret.Body, "", "note content mismatch")
assert.Equal(t, ret.Deleted, true, "note deleted flag mismatch")
assert.Equal(t, ret.USN, tc.expectedUSN, "note label mismatch")
+
+ // Assert FTS body is empty after delete (row still exists but content is cleared)
+ var ftsBody string
+ testutils.MustExec(t, db.Raw("SELECT body FROM notes_fts WHERE rowid = ?", noteRecord.ID).Scan(&ftsBody), fmt.Sprintf("querying notes_fts after delete for test case %d", idx))
+ assert.Equal(t, ftsBody, "", "FTS body should be empty after delete")
}()
}
}
+
+func TestGetNotes_FTSSearch(t *testing.T) {
+ db := testutils.InitMemoryDB(t)
+
+ user := testutils.SetupUserData(db)
+ b1 := database.Book{UserID: user.ID, Label: "testBook"}
+ testutils.MustExec(t, db.Save(&b1), "preparing book")
+
+ // Create notes with different content
+ note1 := database.Note{UserID: user.ID, Deleted: false, Body: "foo bar baz bar", BookUUID: b1.UUID}
+ testutils.MustExec(t, db.Save(¬e1), "preparing note1")
+
+ note2 := database.Note{UserID: user.ID, Deleted: false, Body: "hello run foo", BookUUID: b1.UUID}
+ testutils.MustExec(t, db.Save(¬e2), "preparing note2")
+
+ note3 := database.Note{UserID: user.ID, Deleted: false, Body: "running quz succeeded", BookUUID: b1.UUID}
+ testutils.MustExec(t, db.Save(¬e3), "preparing note3")
+
+ a := NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+
+ // Search "baz"
+ result, err := a.GetNotes(user.ID, GetNotesParams{
+ Search: "baz",
+ Encrypted: false,
+ Page: 1,
+ PerPage: 30,
+ })
+ if err != nil {
+ t.Fatal(errors.Wrap(err, "getting notes with FTS search"))
+ }
+ assert.Equal(t, result.Total, int64(1), "Should find 1 note with 'baz'")
+ assert.Equal(t, len(result.Notes), 1, "Should return 1 note")
+ for i, note := range result.Notes {
+ assert.Equal(t, strings.Contains(note.Body, "baz"), true, fmt.Sprintf("Note %d should contain highlighted dnote", i))
+ }
+
+ // Search for "running" - should return 1 note
+ result, err = a.GetNotes(user.ID, GetNotesParams{
+ Search: "running",
+ Encrypted: false,
+ Page: 1,
+ PerPage: 30,
+ })
+ if err != nil {
+ t.Fatal(errors.Wrap(err, "getting notes with FTS search for review"))
+ }
+ assert.Equal(t, result.Total, int64(2), "Should find 2 note with 'running'")
+ assert.Equal(t, len(result.Notes), 2, "Should return 2 notes")
+ assert.Equal(t, result.Notes[0].Body, "running quz succeeded", "Should return the review note with highlighting")
+ assert.Equal(t, result.Notes[1].Body, "hello run foo", "Should return the review note with highlighting")
+
+ // Search for non-existent term - should return 0 notes
+ result, err = a.GetNotes(user.ID, GetNotesParams{
+ Search: "nonexistent",
+ Encrypted: false,
+ Page: 1,
+ PerPage: 30,
+ })
+ if err != nil {
+ t.Fatal(errors.Wrap(err, "getting notes with FTS search for nonexistent"))
+ }
+
+ assert.Equal(t, result.Total, int64(0), "Should find 0 notes with 'nonexistent'")
+ assert.Equal(t, len(result.Notes), 0, "Should return 0 notes")
+}
+
+func TestGetNotes_FTSSearch_Snippet(t *testing.T) {
+ db := testutils.InitMemoryDB(t)
+
+ user := testutils.SetupUserData(db)
+ b1 := database.Book{UserID: user.ID, Label: "testBook"}
+ testutils.MustExec(t, db.Save(&b1), "preparing book")
+
+ // Create a long note to test snippet truncation with "..."
+ // The snippet limit is 50 tokens, so we generate enough words to exceed it
+ longBody := strings.Repeat("filler ", 100) + "the important keyword appears here"
+ longNote := database.Note{UserID: user.ID, Deleted: false, Body: longBody, BookUUID: b1.UUID}
+ testutils.MustExec(t, db.Save(&longNote), "preparing long note")
+
+ a := NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+
+ // Search for "keyword" in long note - should return snippet with "..."
+ result, err := a.GetNotes(user.ID, GetNotesParams{
+ Search: "keyword",
+ Encrypted: false,
+ Page: 1,
+ PerPage: 30,
+ })
+ if err != nil {
+ t.Fatal(errors.Wrap(err, "getting notes with FTS search for keyword"))
+ }
+
+ assert.Equal(t, result.Total, int64(1), "Should find 1 note with 'keyword'")
+ assert.Equal(t, len(result.Notes), 1, "Should return 1 note")
+ // The snippet should contain "..." to indicate truncation and the highlighted keyword
+ assert.Equal(t, strings.Contains(result.Notes[0].Body, "..."), true, "Snippet should contain '...' for truncation")
+ assert.Equal(t, strings.Contains(result.Notes[0].Body, "keyword"), true, "Snippet should contain highlighted keyword")
+}
+
+func TestGetNotes_FTSSearch_ShortWord(t *testing.T) {
+ db := testutils.InitMemoryDB(t)
+
+ user := testutils.SetupUserData(db)
+ b1 := database.Book{UserID: user.ID, Label: "testBook"}
+ testutils.MustExec(t, db.Save(&b1), "preparing book")
+
+ // Create notes with short words
+ note1 := database.Note{UserID: user.ID, Deleted: false, Body: "a b c", BookUUID: b1.UUID}
+ testutils.MustExec(t, db.Save(¬e1), "preparing note1")
+
+ note2 := database.Note{UserID: user.ID, Deleted: false, Body: "d", BookUUID: b1.UUID}
+ testutils.MustExec(t, db.Save(¬e2), "preparing note2")
+
+ a := NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+
+ result, err := a.GetNotes(user.ID, GetNotesParams{
+ Search: "a",
+ Encrypted: false,
+ Page: 1,
+ PerPage: 30,
+ })
+ if err != nil {
+ t.Fatal(errors.Wrap(err, "getting notes with FTS search for 'a'"))
+ }
+
+ assert.Equal(t, result.Total, int64(1), "Should find 1 note")
+ assert.Equal(t, len(result.Notes), 1, "Should return 1 note")
+ assert.Equal(t, strings.Contains(result.Notes[0].Body, "a"), true, "Should contain highlighted 'a'")
+}
+
+func TestGetNotes_All(t *testing.T) {
+ db := testutils.InitMemoryDB(t)
+
+ user := testutils.SetupUserData(db)
+ b1 := database.Book{UserID: user.ID, Label: "testBook"}
+ testutils.MustExec(t, db.Save(&b1), "preparing book")
+
+ note1 := database.Note{UserID: user.ID, Deleted: false, Body: "a b c", BookUUID: b1.UUID}
+ testutils.MustExec(t, db.Save(¬e1), "preparing note1")
+
+ note2 := database.Note{UserID: user.ID, Deleted: false, Body: "d", BookUUID: b1.UUID}
+ testutils.MustExec(t, db.Save(¬e2), "preparing note2")
+
+ a := NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+
+ result, err := a.GetNotes(user.ID, GetNotesParams{
+ Search: "",
+ Encrypted: false,
+ Page: 1,
+ PerPage: 30,
+ })
+ if err != nil {
+ t.Fatal(errors.Wrap(err, "getting notes with FTS search for 'a'"))
+ }
+
+ assert.Equal(t, result.Total, int64(2), "Should not find all notes")
+ assert.Equal(t, len(result.Notes), 2, "Should not find all notes")
+
+ for _, note := range result.Notes {
+ assert.Equal(t, strings.Contains(note.Body, ""), false, "There should be no keywords")
+ assert.Equal(t, strings.Contains(note.Body, ""), false, "There should be no keywords")
+ }
+ assert.Equal(t, result.Notes[0].Body, "d", "Full content should be returned")
+ assert.Equal(t, result.Notes[1].Body, "a b c", "Full content should be returned")
+}
diff --git a/pkg/server/app/testutils.go b/pkg/server/app/testutils.go
index cc19a1e2..06664c5f 100644
--- a/pkg/server/app/testutils.go
+++ b/pkg/server/app/testutils.go
@@ -19,50 +19,24 @@
package app
import (
- "fmt"
-
"github.com/dnote/dnote/pkg/clock"
- "github.com/dnote/dnote/pkg/server/config"
+ "github.com/dnote/dnote/pkg/server/assets"
"github.com/dnote/dnote/pkg/server/mailer"
"github.com/dnote/dnote/pkg/server/testutils"
)
// NewTest returns an app for a testing environment
-func NewTest(appParams *App) App {
- c := config.Load()
- c.SetOnPremises(false)
-
- a := App{
- DB: testutils.DB,
- Clock: clock.NewMock(),
- EmailTemplates: mailer.NewTemplates(),
- EmailBackend: &testutils.MockEmailbackendImplementation{},
- Config: c,
- HTTP500Page: []byte(""),
+func NewTest() App {
+ return App{
+ Clock: clock.NewMock(),
+ EmailTemplates: mailer.NewTemplates(),
+ EmailBackend: &testutils.MockEmailbackendImplementation{},
+ HTTP500Page: assets.MustGetHTTP500ErrorPage(),
+ AppEnv: "TEST",
+ WebURL: "http://127.0.0.0.1",
+ Port: "3000",
+ DisableRegistration: false,
+ DBPath: ":memory:",
+ AssetBaseURL: "",
}
-
- // Allow to override with appParams
- if appParams != nil && appParams.EmailBackend != nil {
- a.EmailBackend = appParams.EmailBackend
- }
- if appParams != nil && appParams.Clock != nil {
- a.Clock = appParams.Clock
- }
- if appParams != nil && appParams.EmailTemplates != nil {
- a.EmailTemplates = appParams.EmailTemplates
- }
- if appParams != nil && appParams.Config.OnPremises {
- a.Config.OnPremises = appParams.Config.OnPremises
- }
- if appParams != nil && appParams.Config.WebURL != "" {
- a.Config.WebURL = appParams.Config.WebURL
- }
- if appParams != nil && appParams.Config.DisableRegistration {
- a.Config.DisableRegistration = appParams.Config.DisableRegistration
- }
-
- fmt.Printf("%+v\n", appParams)
- fmt.Printf("%+v\n", a)
-
- return a
}
diff --git a/pkg/server/app/users.go b/pkg/server/app/users.go
index a77e8cdb..5c9d7f1f 100644
--- a/pkg/server/app/users.go
+++ b/pkg/server/app/users.go
@@ -22,11 +22,11 @@ import (
"errors"
"github.com/dnote/dnote/pkg/server/database"
+ "github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/log"
- "github.com/dnote/dnote/pkg/server/token"
- "gorm.io/gorm"
pkgErrors "github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
+ "gorm.io/gorm"
)
// TouchLastLoginAt updates the last login timestamp
@@ -39,17 +39,6 @@ func (a *App) TouchLastLoginAt(user database.User, tx *gorm.DB) error {
return nil
}
-func createEmailPreference(user database.User, tx *gorm.DB) error {
- p := database.EmailPreference{
- UserID: user.ID,
- }
- if err := tx.Save(&p).Error; err != nil {
- return pkgErrors.Wrap(err, "inserting email preference")
- }
-
- return nil
-}
-
// CreateUser creates a user
func (a *App) CreateUser(email, password string, passwordConfirmation string) (database.User, error) {
if email == "" {
@@ -80,16 +69,14 @@ func (a *App) CreateUser(email, password string, passwordConfirmation string) (d
return database.User{}, pkgErrors.Wrap(err, "hashing password")
}
- // Grant all privileges if self-hosting
- var pro bool
- if a.Config.OnPremises {
- pro = true
- } else {
- pro = false
+ uuid, err := helpers.GenUUID()
+ if err != nil {
+ tx.Rollback()
+ return database.User{}, pkgErrors.Wrap(err, "generating UUID")
}
user := database.User{
- Cloud: pro,
+ UUID: uuid,
}
if err = tx.Save(&user).Error; err != nil {
tx.Rollback()
@@ -105,14 +92,6 @@ func (a *App) CreateUser(email, password string, passwordConfirmation string) (d
return database.User{}, pkgErrors.Wrap(err, "saving account")
}
- if _, err := token.Create(tx, user.ID, database.TokenTypeEmailPreference); err != nil {
- tx.Rollback()
- return database.User{}, pkgErrors.Wrap(err, "creating email verificaiton token")
- }
- if err := createEmailPreference(user, tx); err != nil {
- tx.Rollback()
- return database.User{}, pkgErrors.Wrap(err, "creating email preference")
- }
if err := a.TouchLastLoginAt(user, tx); err != nil {
tx.Rollback()
return database.User{}, pkgErrors.Wrap(err, "updating last login")
diff --git a/pkg/server/app/users_test.go b/pkg/server/app/users_test.go
index ba33d11e..45c43514 100644
--- a/pkg/server/app/users_test.go
+++ b/pkg/server/app/users_test.go
@@ -19,11 +19,9 @@
package app
import (
- "fmt"
"testing"
"github.com/dnote/dnote/pkg/assert"
- "github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
@@ -31,65 +29,41 @@ import (
)
func TestCreateUser_ProValue(t *testing.T) {
- testCases := []struct {
- onPremises bool
- expectedPro bool
- }{
- {
- onPremises: true,
- expectedPro: true,
- },
- {
- onPremises: false,
- expectedPro: false,
- },
+ db := testutils.InitMemoryDB(t)
+
+ a := NewTest()
+ a.DB = db
+ if _, err := a.CreateUser("alice@example.com", "pass1234", "pass1234"); err != nil {
+ t.Fatal(errors.Wrap(err, "executing"))
}
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("self hosting %t", tc.onPremises), func(t *testing.T) {
- c := config.Load()
- c.SetOnPremises(tc.onPremises)
+ var userCount int64
+ var userRecord database.User
+ testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
+ testutils.MustExec(t, db.First(&userRecord), "finding user")
- defer testutils.ClearData(testutils.DB)
+ assert.Equal(t, userCount, int64(1), "book count mismatch")
- a := NewTest(&App{
- Config: c,
- })
- if _, err := a.CreateUser("alice@example.com", "pass1234", "pass1234"); err != nil {
- t.Fatal(errors.Wrap(err, "executing"))
- }
-
- var userCount int64
- var userRecord database.User
- testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
- testutils.MustExec(t, testutils.DB.First(&userRecord), "finding user")
-
- assert.Equal(t, userCount, int64(1), "book count mismatch")
- assert.Equal(t, userRecord.Cloud, tc.expectedPro, "user pro mismatch")
- })
- }
}
func TestCreateUser(t *testing.T) {
t.Run("success", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
- c := config.Load()
- a := NewTest(&App{
- Config: c,
- })
+ a := NewTest()
+ a.DB = db
if _, err := a.CreateUser("alice@example.com", "pass1234", "pass1234"); err != nil {
t.Fatal(errors.Wrap(err, "executing"))
}
var userCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
+ testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
assert.Equal(t, userCount, int64(1), "book count mismatch")
var accountCount int64
var accountRecord database.Account
- testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account")
- testutils.MustExec(t, testutils.DB.First(&accountRecord), "finding account")
+ testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
+ testutils.MustExec(t, db.First(&accountRecord), "finding account")
assert.Equal(t, accountCount, int64(1), "account count mismatch")
assert.Equal(t, accountRecord.Email.String, "alice@example.com", "account email mismatch")
@@ -99,19 +73,20 @@ func TestCreateUser(t *testing.T) {
})
t.Run("duplicate email", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
- aliceUser := testutils.SetupUserData()
- testutils.SetupAccountData(aliceUser, "alice@example.com", "somepassword")
+ aliceUser := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, aliceUser, "alice@example.com", "somepassword")
- a := NewTest(nil)
+ a := NewTest()
+ a.DB = db
_, err := a.CreateUser("alice@example.com", "newpassword", "newpassword")
assert.Equal(t, err, ErrDuplicateEmail, "error mismatch")
var userCount, accountCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
- testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account")
+ testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
+ testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
assert.Equal(t, userCount, int64(1), "user count mismatch")
assert.Equal(t, accountCount, int64(1), "account count mismatch")
diff --git a/pkg/server/config/config.go b/pkg/server/config/config.go
index 5a4dfa58..916bae4e 100644
--- a/pkg/server/config/config.go
+++ b/pkg/server/config/config.go
@@ -19,149 +19,94 @@
package config
import (
- "fmt"
"net/url"
"os"
+ "path/filepath"
+ "github.com/dnote/dnote/pkg/dirs"
"github.com/dnote/dnote/pkg/server/assets"
- "github.com/dnote/dnote/pkg/server/log"
"github.com/pkg/errors"
)
const (
// AppEnvProduction represents an app environment for production.
AppEnvProduction string = "PRODUCTION"
+ // DefaultDBDir is the default directory name for Dnote data
+ DefaultDBDir = "dnote"
+ // DefaultDBFilename is the default database filename
+ DefaultDBFilename = "server.db"
)
var (
- // ErrDBMissingHost is an error for an incomplete configuration missing the host
- ErrDBMissingHost = errors.New("DB Host is empty")
- // ErrDBMissingPort is an error for an incomplete configuration missing the port
- ErrDBMissingPort = errors.New("DB Port is empty")
- // ErrDBMissingName is an error for an incomplete configuration missing the name
- ErrDBMissingName = errors.New("DB Name is empty")
- // ErrDBMissingUser is an error for an incomplete configuration missing the user
- ErrDBMissingUser = errors.New("DB User is empty")
+ // DefaultDBPath is the default path to the database file
+ DefaultDBPath = filepath.Join(dirs.DataHome, DefaultDBDir, DefaultDBFilename)
+)
+
+var (
+ // ErrDBMissingPath is an error for an incomplete configuration missing the database path
+ ErrDBMissingPath = errors.New("DB Path is empty")
// ErrWebURLInvalid is an error for an incomplete configuration with invalid web url
ErrWebURLInvalid = errors.New("Invalid WebURL")
// ErrPortInvalid is an error for an incomplete configuration with invalid port
ErrPortInvalid = errors.New("Invalid Port")
)
-// PostgresConfig holds the postgres connection configuration.
-type PostgresConfig struct {
- SSLMode string
- Host string
- Port string
- Name string
- User string
- Password string
-}
-
func readBoolEnv(name string) bool {
- if os.Getenv(name) == "true" {
- return true
- }
-
- return false
+ return os.Getenv(name) == "true"
}
-// checkSSLMode checks if SSL is required for the database connection
-func checkSSLMode() bool {
- // TODO: deprecate DB_NOSSL in favor of DBSkipSSL
- if os.Getenv("DB_NOSSL") != "" {
- return true
+// getOrEnv returns value if non-empty, otherwise env var, otherwise default
+func getOrEnv(value, envKey, defaultVal string) string {
+ if value != "" {
+ return value
}
-
- if os.Getenv("DBSkipSSL") == "true" {
- return true
- }
-
- return os.Getenv("GO_ENV") != "PRODUCTION"
-}
-
-func loadDBConfig() PostgresConfig {
- var sslmode string
- if checkSSLMode() {
- sslmode = "disable"
- } else {
- sslmode = "require"
- }
-
- return PostgresConfig{
- SSLMode: sslmode,
- Host: os.Getenv("DBHost"),
- Port: os.Getenv("DBPort"),
- Name: os.Getenv("DBName"),
- User: os.Getenv("DBUser"),
- Password: os.Getenv("DBPassword"),
+ if env := os.Getenv(envKey); env != "" {
+ return env
}
+ return defaultVal
}
// Config is an application configuration
type Config struct {
AppEnv string
WebURL string
- OnPremises bool
DisableRegistration bool
Port string
- DB PostgresConfig
+ DBPath string
AssetBaseURL string
HTTP500Page []byte
+ LogLevel string
}
-func getAppEnv() string {
- // DEPRECATED
- goEnv := os.Getenv("GO_ENV")
- if goEnv != "" {
- return goEnv
- }
-
- return os.Getenv("APP_ENV")
+// Params are the configuration parameters for creating a new Config
+type Params struct {
+ AppEnv string
+ Port string
+ WebURL string
+ DBPath string
+ DisableRegistration bool
+ LogLevel string
}
-func checkDeprecatedEnvVars() {
- if os.Getenv("OnPremise") != "" {
-
- log.WithFields(log.Fields{}).Warn("Environment variable 'OnPremise' is deprecated. Please use OnPremises.")
- }
-}
-
-// Load constructs and returns a new config based on the environment variables.
-func Load() Config {
- port := os.Getenv("PORT")
- if port == "" {
- port = "3000"
- }
-
- checkDeprecatedEnvVars()
-
+// New constructs and returns a new validated config.
+// Empty string params will fall back to environment variables and defaults.
+func New(p Params) (Config, error) {
c := Config{
- AppEnv: getAppEnv(),
- WebURL: os.Getenv("WebURL"),
- Port: port,
- OnPremises: readBoolEnv("OnPremise") || readBoolEnv("OnPremises"),
- DisableRegistration: readBoolEnv("DisableRegistration"),
- DB: loadDBConfig(),
- AssetBaseURL: "",
+ AppEnv: getOrEnv(p.AppEnv, "APP_ENV", AppEnvProduction),
+ Port: getOrEnv(p.Port, "PORT", "3000"),
+ WebURL: getOrEnv(p.WebURL, "WebURL", ""),
+ DBPath: getOrEnv(p.DBPath, "DBPath", DefaultDBPath),
+ DisableRegistration: p.DisableRegistration || readBoolEnv("DisableRegistration"),
+ LogLevel: getOrEnv(p.LogLevel, "LOG_LEVEL", "info"),
+ AssetBaseURL: "/static",
HTTP500Page: assets.MustGetHTTP500ErrorPage(),
}
if err := validate(c); err != nil {
- panic(err)
+ return Config{}, err
}
- return c
-}
-
-// SetOnPremises sets the OnPremise value
-func (c *Config) SetOnPremises(val bool) {
- c.OnPremises = val
-}
-
-// SetAssetBaseURL sets static dir for the confi
-func (c *Config) SetAssetBaseURL(d string) {
- c.AssetBaseURL = d
+ return c, nil
}
// IsProd checks if the app environment is configured to be production.
@@ -171,31 +116,15 @@ func (c Config) IsProd() bool {
func validate(c Config) error {
if _, err := url.ParseRequestURI(c.WebURL); err != nil {
- return errors.Wrapf(ErrWebURLInvalid, "provided: '%s'", c.WebURL)
+ return errors.Wrapf(ErrWebURLInvalid, "'%s'", c.WebURL)
}
if c.Port == "" {
return ErrPortInvalid
}
- if c.DB.Host == "" {
- return ErrDBMissingHost
- }
- if c.DB.Port == "" {
- return ErrDBMissingPort
- }
- if c.DB.Name == "" {
- return ErrDBMissingName
- }
- if c.DB.User == "" {
- return ErrDBMissingUser
+ if c.DBPath == "" {
+ return ErrDBMissingPath
}
return nil
}
-
-// GetConnectionStr returns a postgres connection string.
-func (c PostgresConfig) GetConnectionStr() string {
- return fmt.Sprintf(
- "sslmode=%s host=%s port=%s dbname=%s user=%s password=%s",
- c.SSLMode, c.Host, c.Port, c.Name, c.User, c.Password)
-}
diff --git a/pkg/server/config/config_test.go b/pkg/server/config/config_test.go
index 9f57b404..802a3add 100644
--- a/pkg/server/config/config_test.go
+++ b/pkg/server/config/config_test.go
@@ -33,12 +33,7 @@ func TestValidate(t *testing.T) {
}{
{
config: Config{
- DB: PostgresConfig{
- Host: "mockHost",
- Port: "5432",
- Name: "mockDB",
- User: "mockUser",
- },
+ DBPath: "test.db",
WebURL: "http://mock.url",
Port: "3000",
},
@@ -46,71 +41,21 @@ func TestValidate(t *testing.T) {
},
{
config: Config{
- DB: PostgresConfig{
- Port: "5432",
- Name: "mockDB",
- User: "mockUser",
- },
+ DBPath: "",
WebURL: "http://mock.url",
Port: "3000",
},
- expectedErr: ErrDBMissingHost,
+ expectedErr: ErrDBMissingPath,
},
{
config: Config{
- DB: PostgresConfig{
- Host: "mockHost",
- Name: "mockDB",
- User: "mockUser",
- },
- WebURL: "http://mock.url",
- Port: "3000",
- },
- expectedErr: ErrDBMissingPort,
- },
- {
- config: Config{
- DB: PostgresConfig{
- Host: "mockHost",
- Port: "5432",
- User: "mockUser",
- },
- WebURL: "http://mock.url",
- Port: "3000",
- },
- expectedErr: ErrDBMissingName,
- },
- {
- config: Config{
- DB: PostgresConfig{
- Host: "mockHost",
- Port: "5432",
- Name: "mockDB",
- },
- WebURL: "http://mock.url",
- Port: "3000",
- },
- expectedErr: ErrDBMissingUser,
- },
- {
- config: Config{
- DB: PostgresConfig{
- Host: "mockHost",
- Port: "5432",
- Name: "mockDB",
- User: "mockUser",
- },
+ DBPath: "test.db",
},
expectedErr: ErrWebURLInvalid,
},
{
config: Config{
- DB: PostgresConfig{
- Host: "mockHost",
- Port: "5432",
- Name: "mockDB",
- User: "mockUser",
- },
+ DBPath: "test.db",
WebURL: "http://mock.url",
},
expectedErr: ErrPortInvalid,
diff --git a/pkg/server/controllers/books_test.go b/pkg/server/controllers/books_test.go
index 6f0e3704..d59dcd17 100644
--- a/pkg/server/controllers/books_test.go
+++ b/pkg/server/controllers/books_test.go
@@ -21,69 +21,78 @@ package controllers
import (
"encoding/json"
"fmt"
- "io/ioutil"
+ "io"
"net/http"
"testing"
+ "time"
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/app"
- "github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
)
+// truncateMicro rounds time to microsecond precision to match SQLite storage
+func truncateMicro(t time.Time) time.Time {
+ return t.Round(time.Microsecond)
+}
+
func TestGetBooks(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@test.com", "pass1234")
- anotherUser := testutils.SetupUserData()
- testutils.SetupAccountData(anotherUser, "bob@test.com", "pass1234")
+ user := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
+ anotherUser := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
b1 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
USN: 1123,
Deleted: false,
}
- testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
+ testutils.MustExec(t, db.Save(&b1), "preparing b1")
b2 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "css",
USN: 1125,
Deleted: false,
}
- testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
+ testutils.MustExec(t, db.Save(&b2), "preparing b2")
b3 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: anotherUser.ID,
Label: "css",
USN: 1128,
Deleted: false,
}
- testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
+ testutils.MustExec(t, db.Save(&b3), "preparing b3")
b4 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "",
USN: 1129,
Deleted: true,
}
- testutils.MustExec(t, testutils.DB.Save(&b4), "preparing b4")
+ testutils.MustExec(t, db.Save(&b4), "preparing b4")
// Execute
endpoint := "/api/v3/books"
req := testutils.MakeReq(server.URL, "GET", endpoint, "")
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "")
@@ -94,66 +103,75 @@ func TestGetBooks(t *testing.T) {
}
var b1Record, b2Record database.Book
- testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
- testutils.MustExec(t, testutils.DB.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
- testutils.MustExec(t, testutils.DB.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
+ testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
+ testutils.MustExec(t, db.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
+ testutils.MustExec(t, db.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
expected := []presenters.Book{
{
UUID: b2Record.UUID,
- CreatedAt: b2Record.CreatedAt,
- UpdatedAt: b2Record.UpdatedAt,
+ CreatedAt: truncateMicro(b2Record.CreatedAt),
+ UpdatedAt: truncateMicro(b2Record.UpdatedAt),
Label: b2Record.Label,
USN: b2Record.USN,
},
{
UUID: b1Record.UUID,
- CreatedAt: b1Record.CreatedAt,
- UpdatedAt: b1Record.UpdatedAt,
+ CreatedAt: truncateMicro(b1Record.CreatedAt),
+ UpdatedAt: truncateMicro(b1Record.UpdatedAt),
Label: b1Record.Label,
USN: b1Record.USN,
},
}
+ // Truncate payload timestamps to match SQLite precision
+ for i := range payload {
+ payload[i].CreatedAt = truncateMicro(payload[i].CreatedAt)
+ payload[i].UpdatedAt = truncateMicro(payload[i].UpdatedAt)
+ }
+
assert.DeepEqual(t, payload, expected, "payload mismatch")
}
func TestGetBooksByName(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@test.com", "pass1234")
- anotherUser := testutils.SetupUserData()
- testutils.SetupAccountData(anotherUser, "bob@test.com", "pass1234")
+ user := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
+ anotherUser := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
b1 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
+ testutils.MustExec(t, db.Save(&b1), "preparing b1")
b2 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "css",
}
- testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
+ testutils.MustExec(t, db.Save(&b2), "preparing b2")
b3 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: anotherUser.ID,
Label: "js",
}
- testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
+ testutils.MustExec(t, db.Save(&b3), "preparing b3")
// Execute
endpoint := "/api/v3/books?name=js"
req := testutils.MakeReq(server.URL, "GET", endpoint, "")
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "")
@@ -164,56 +182,64 @@ func TestGetBooksByName(t *testing.T) {
}
var b1Record database.Book
- testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
+ testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
expected := []presenters.Book{
{
UUID: b1Record.UUID,
- CreatedAt: b1Record.CreatedAt,
- UpdatedAt: b1Record.UpdatedAt,
+ CreatedAt: truncateMicro(b1Record.CreatedAt),
+ UpdatedAt: truncateMicro(b1Record.UpdatedAt),
Label: b1Record.Label,
USN: b1Record.USN,
},
}
+ for i := range payload {
+ payload[i].CreatedAt = truncateMicro(payload[i].CreatedAt)
+ payload[i].UpdatedAt = truncateMicro(payload[i].UpdatedAt)
+ }
+
assert.DeepEqual(t, payload, expected, "payload mismatch")
}
func TestGetBook(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@test.com", "pass1234")
- anotherUser := testutils.SetupUserData()
- testutils.SetupAccountData(anotherUser, "bob@test.com", "pass1234")
+ user := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
+ anotherUser := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
b1 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
+ testutils.MustExec(t, db.Save(&b1), "preparing b1")
b2 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "css",
}
- testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
+ testutils.MustExec(t, db.Save(&b2), "preparing b2")
b3 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: anotherUser.ID,
Label: "js",
}
- testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
+ testutils.MustExec(t, db.Save(&b3), "preparing b3")
// Execute
endpoint := fmt.Sprintf("/api/v3/books/%s", b1.UUID)
req := testutils.MakeReq(server.URL, "GET", endpoint, "")
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "")
@@ -224,49 +250,53 @@ func TestGetBook(t *testing.T) {
}
var b1Record database.Book
- testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
+ testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
expected := presenters.Book{
UUID: b1Record.UUID,
- CreatedAt: b1Record.CreatedAt,
- UpdatedAt: b1Record.UpdatedAt,
+ CreatedAt: truncateMicro(b1Record.CreatedAt),
+ UpdatedAt: truncateMicro(b1Record.UpdatedAt),
Label: b1Record.Label,
USN: b1Record.USN,
}
+ payload.CreatedAt = truncateMicro(payload.CreatedAt)
+ payload.UpdatedAt = truncateMicro(payload.UpdatedAt)
+
assert.DeepEqual(t, payload, expected, "payload mismatch")
}
func TestGetBookNonOwner(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@test.com", "pass1234")
- nonOwner := testutils.SetupUserData()
- testutils.SetupAccountData(nonOwner, "bob@test.com", "pass1234")
+ user := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
+ nonOwner := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, nonOwner, "bob@test.com", "pass1234")
b1 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
+ testutils.MustExec(t, db.Save(&b1), "preparing b1")
// Execute
endpoint := fmt.Sprintf("/api/v3/books/%s", b1.UUID)
req := testutils.MakeReq(server.URL, "GET", endpoint, "")
- res := testutils.HTTPAuthDo(t, req, nonOwner)
+ res := testutils.HTTPAuthDo(t, db, req, nonOwner)
// Test
assert.StatusCodeEquals(t, res, http.StatusNotFound, "")
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(errors.Wrap(err, "reading body"))
}
@@ -275,23 +305,23 @@ func TestGetBookNonOwner(t *testing.T) {
func TestCreateBook(t *testing.T) {
t.Run("success", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@test.com", "pass1234")
- testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
+ user := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
+ testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
req := testutils.MakeReq(server.URL, "POST", "/api/v3/books", `{"name": "js"}`)
// Execute
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusCreated, "")
@@ -299,10 +329,10 @@ func TestCreateBook(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
var bookCount, noteCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
- testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
- testutils.MustExec(t, testutils.DB.First(&bookRecord), "finding book")
- testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
+ testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
+ testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
+ testutils.MustExec(t, db.First(&bookRecord), "finding book")
+ testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
maxUSN := 102
@@ -323,39 +353,43 @@ func TestCreateBook(t *testing.T) {
Book: presenters.Book{
UUID: bookRecord.UUID,
USN: bookRecord.USN,
- CreatedAt: bookRecord.CreatedAt,
- UpdatedAt: bookRecord.UpdatedAt,
+ CreatedAt: truncateMicro(bookRecord.CreatedAt),
+ UpdatedAt: truncateMicro(bookRecord.UpdatedAt),
Label: "js",
},
}
+ got.Book.CreatedAt = truncateMicro(got.Book.CreatedAt)
+ got.Book.UpdatedAt = truncateMicro(got.Book.UpdatedAt)
+
assert.DeepEqual(t, got, expected, "payload mismatch")
})
t.Run("duplicate", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@test.com", "pass1234")
- testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
+ user := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
+ testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
USN: 58,
}
- testutils.MustExec(t, testutils.DB.Save(&b1), "preparing book data")
+ testutils.MustExec(t, db.Save(&b1), "preparing book data")
// Execute
req := testutils.MakeReq(server.URL, "POST", "/api/v3/books", `{"name": "js"}`)
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusConflict, "")
@@ -363,10 +397,10 @@ func TestCreateBook(t *testing.T) {
var bookRecord database.Book
var bookCount, noteCount int64
var userRecord database.User
- testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
- testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
- testutils.MustExec(t, testutils.DB.First(&bookRecord), "finding book")
- testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
+ testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
+ testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
+ testutils.MustExec(t, db.First(&bookRecord), "finding book")
+ testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, int64(1), "book count mismatch")
assert.Equalf(t, noteCount, int64(0), "note count mismatch")
@@ -422,18 +456,18 @@ func TestUpdateBook(t *testing.T) {
for idx, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@test.com", "pass1234")
- testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
+ user := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
+ testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
UUID: tc.bookUUID,
@@ -441,18 +475,18 @@ func TestUpdateBook(t *testing.T) {
Label: tc.bookLabel,
Deleted: tc.bookDeleted,
}
- testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
+ testutils.MustExec(t, db.Save(&b1), "preparing b1")
b2 := database.Book{
UUID: b2UUID,
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
+ testutils.MustExec(t, db.Save(&b2), "preparing b2")
// Execute
endpoint := fmt.Sprintf("/api/v3/books/%s", tc.bookUUID)
req := testutils.MakeReq(server.URL, "PATCH", endpoint, tc.payload.ToJSON(t))
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, fmt.Sprintf("status code mismatch for test case %d", idx))
@@ -460,10 +494,10 @@ func TestUpdateBook(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
var noteCount, bookCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
- testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
- testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
- testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
+ testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
+ testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
+ testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
+ testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, int64(2), "book count mismatch")
assert.Equalf(t, noteCount, int64(0), "note count mismatch")
@@ -507,41 +541,44 @@ func TestDeleteBook(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("originally deleted %t", tc.deleted), func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@test.com", "pass1234")
- testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 58), "preparing user max_usn")
- anotherUser := testutils.SetupUserData()
- testutils.SetupAccountData(anotherUser, "bob@test.com", "pass1234")
- testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 109), "preparing another user max_usn")
+ user := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
+ testutils.MustExec(t, db.Model(&user).Update("max_usn", 58), "preparing user max_usn")
+ anotherUser := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
+ testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 109), "preparing another user max_usn")
b1 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
USN: 1,
}
- testutils.MustExec(t, testutils.DB.Save(&b1), "preparing a book data")
+ testutils.MustExec(t, db.Save(&b1), "preparing a book data")
b2 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: tc.label,
USN: 2,
Deleted: tc.deleted,
}
- testutils.MustExec(t, testutils.DB.Save(&b2), "preparing a book data")
+ testutils.MustExec(t, db.Save(&b2), "preparing a book data")
b3 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: anotherUser.ID,
Label: "linux",
USN: 3,
}
- testutils.MustExec(t, testutils.DB.Save(&b3), "preparing a book data")
+ testutils.MustExec(t, db.Save(&b3), "preparing a book data")
var n2Body string
if !tc.deleted {
@@ -553,49 +590,54 @@ func TestDeleteBook(t *testing.T) {
}
n1 := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: "n1 content",
USN: 4,
}
- testutils.MustExec(t, testutils.DB.Save(&n1), "preparing a note data")
+ testutils.MustExec(t, db.Save(&n1), "preparing a note data")
n2 := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b2.UUID,
Body: n2Body,
USN: 5,
Deleted: tc.deleted,
}
- testutils.MustExec(t, testutils.DB.Save(&n2), "preparing a note data")
+ testutils.MustExec(t, db.Save(&n2), "preparing a note data")
n3 := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b2.UUID,
Body: n3Body,
USN: 6,
Deleted: tc.deleted,
}
- testutils.MustExec(t, testutils.DB.Save(&n3), "preparing a note data")
+ testutils.MustExec(t, db.Save(&n3), "preparing a note data")
n4 := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b2.UUID,
Body: "",
USN: 7,
Deleted: true,
}
- testutils.MustExec(t, testutils.DB.Save(&n4), "preparing a note data")
+ testutils.MustExec(t, db.Save(&n4), "preparing a note data")
n5 := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: anotherUser.ID,
BookUUID: b3.UUID,
Body: "n5 content",
USN: 8,
}
- testutils.MustExec(t, testutils.DB.Save(&n5), "preparing a note data")
+ testutils.MustExec(t, db.Save(&n5), "preparing a note data")
// Execute
endpoint := fmt.Sprintf("/api/v3/books/%s", b2.UUID)
req := testutils.MakeReq(server.URL, "DELETE", endpoint, "")
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "")
@@ -605,17 +647,17 @@ func TestDeleteBook(t *testing.T) {
var userRecord database.User
var bookCount, noteCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
- testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
- testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
- testutils.MustExec(t, testutils.DB.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
- testutils.MustExec(t, testutils.DB.Where("id = ?", b3.ID).First(&b3Record), "finding b3")
- testutils.MustExec(t, testutils.DB.Where("id = ?", n1.ID).First(&n1Record), "finding n1")
- testutils.MustExec(t, testutils.DB.Where("id = ?", n2.ID).First(&n2Record), "finding n2")
- testutils.MustExec(t, testutils.DB.Where("id = ?", n3.ID).First(&n3Record), "finding n3")
- testutils.MustExec(t, testutils.DB.Where("id = ?", n4.ID).First(&n4Record), "finding n4")
- testutils.MustExec(t, testutils.DB.Where("id = ?", n5.ID).First(&n5Record), "finding n5")
- testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
+ testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
+ testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
+ testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
+ testutils.MustExec(t, db.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
+ testutils.MustExec(t, db.Where("id = ?", b3.ID).First(&b3Record), "finding b3")
+ testutils.MustExec(t, db.Where("id = ?", n1.ID).First(&n1Record), "finding n1")
+ testutils.MustExec(t, db.Where("id = ?", n2.ID).First(&n2Record), "finding n2")
+ testutils.MustExec(t, db.Where("id = ?", n3.ID).First(&n3Record), "finding n3")
+ testutils.MustExec(t, db.Where("id = ?", n4.ID).First(&n4Record), "finding n4")
+ testutils.MustExec(t, db.Where("id = ?", n5.ID).First(&n5Record), "finding n5")
+ testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equal(t, bookCount, int64(3), "book count mismatch")
assert.Equal(t, noteCount, int64(5), "note count mismatch")
diff --git a/pkg/server/controllers/health_test.go b/pkg/server/controllers/health_test.go
index a1d0294e..cec8d35a 100644
--- a/pkg/server/controllers/health_test.go
+++ b/pkg/server/controllers/health_test.go
@@ -24,16 +24,15 @@ import (
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/server/app"
- "github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/testutils"
)
func TestHealth(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
- server := MustNewServer(t, &app.App{
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
// Execute
diff --git a/pkg/server/controllers/helpers.go b/pkg/server/controllers/helpers.go
index ba697be0..0e88034b 100644
--- a/pkg/server/controllers/helpers.go
+++ b/pkg/server/controllers/helpers.go
@@ -61,13 +61,6 @@ func parseForm(r *http.Request, dst interface{}) error {
return parseValues(r.PostForm, dst)
}
-func parseURLParams(r *http.Request, dst interface{}) error {
- if err := r.ParseForm(); err != nil {
- return err
- }
- return parseValues(r.Form, dst)
-}
-
func parseValues(values url.Values, dst interface{}) error {
dec := schema.NewDecoder()
diff --git a/pkg/server/controllers/main_test.go b/pkg/server/controllers/main_test.go
index 7e48d502..35eef6da 100644
--- a/pkg/server/controllers/main_test.go
+++ b/pkg/server/controllers/main_test.go
@@ -21,15 +21,13 @@ package controllers
import (
"os"
"testing"
-
- "github.com/dnote/dnote/pkg/server/testutils"
+ "time"
)
func TestMain(m *testing.M) {
- testutils.InitTestDB()
+ // Set timezone to UTC to match database timestamps
+ time.Local = time.UTC
code := m.Run()
- testutils.ClearData(testutils.DB)
-
os.Exit(code)
}
diff --git a/pkg/server/controllers/notes.go b/pkg/server/controllers/notes.go
index 0a12c071..a7434366 100644
--- a/pkg/server/controllers/notes.go
+++ b/pkg/server/controllers/notes.go
@@ -19,13 +19,10 @@
package controllers
import (
- "math"
"net/http"
"net/url"
- "sort"
"strconv"
"strings"
- "time"
"github.com/dnote/dnote/pkg/server/app"
"github.com/dnote/dnote/pkg/server/context"
@@ -150,69 +147,6 @@ func (n *Notes) getNotes(r *http.Request) (app.GetNotesResult, app.GetNotesParam
return res, p, nil
}
-type noteGroup struct {
- Year int
- Month int
- Data []database.Note
-}
-
-type bucketKey struct {
- year int
- month time.Month
-}
-
-func groupNotes(notes []database.Note) []noteGroup {
- ret := []noteGroup{}
-
- buckets := map[bucketKey][]database.Note{}
-
- for _, note := range notes {
- year := note.UpdatedAt.Year()
- month := note.UpdatedAt.Month()
- key := bucketKey{year, month}
-
- if _, ok := buckets[key]; !ok {
- buckets[key] = []database.Note{}
- }
-
- buckets[key] = append(buckets[key], note)
- }
-
- keys := []bucketKey{}
- for key := range buckets {
- keys = append(keys, key)
- }
-
- sort.Slice(keys, func(i, j int) bool {
- yearI := keys[i].year
- yearJ := keys[j].year
- monthI := keys[i].month
- monthJ := keys[j].month
-
- if yearI == yearJ {
- return monthI < monthJ
- }
-
- return yearI < yearJ
- })
-
- for _, key := range keys {
- group := noteGroup{
- Year: key.year,
- Month: int(key.month),
- Data: buckets[key],
- }
- ret = append(ret, group)
- }
-
- return ret
-}
-
-func getMaxPage(page, total int) int {
- tmp := float64(total) / float64(notesPerPage)
- return int(math.Ceil(tmp))
-}
-
// GetNotesResponse is a reponse by getNotesHandler
type GetNotesResponse struct {
Notes []presenters.Note `json:"notes"`
diff --git a/pkg/server/controllers/notes_test.go b/pkg/server/controllers/notes_test.go
index caeaa775..6bf3c06b 100644
--- a/pkg/server/controllers/notes_test.go
+++ b/pkg/server/controllers/notes_test.go
@@ -21,7 +21,7 @@ package controllers
import (
"encoding/json"
"fmt"
- "io/ioutil"
+ "io"
"net/http"
"testing"
"time"
@@ -29,7 +29,6 @@ import (
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/app"
- "github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/testutils"
@@ -39,8 +38,8 @@ import (
func getExpectedNotePayload(n database.Note, b database.Book, u database.User) presenters.Note {
return presenters.Note{
UUID: n.UUID,
- CreatedAt: n.CreatedAt,
- UpdatedAt: n.UpdatedAt,
+ CreatedAt: truncateMicro(n.CreatedAt),
+ UpdatedAt: truncateMicro(n.UpdatedAt),
Body: n.Body,
AddedOn: n.AddedOn,
Public: n.Public,
@@ -56,37 +55,41 @@ func getExpectedNotePayload(n database.Note, b database.Book, u database.User) p
}
func TestGetNotes(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@test.com", "pass1234")
- anotherUser := testutils.SetupUserData()
- testutils.SetupAccountData(anotherUser, "bob@test.com", "pass1234")
+ user := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
+ anotherUser := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
b1 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
+ testutils.MustExec(t, db.Save(&b1), "preparing b1")
b2 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "css",
}
- testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
+ testutils.MustExec(t, db.Save(&b2), "preparing b2")
b3 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: anotherUser.ID,
Label: "css",
}
- testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
+ testutils.MustExec(t, db.Save(&b3), "preparing b3")
n1 := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: "n1 content",
@@ -94,8 +97,9 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
- testutils.MustExec(t, testutils.DB.Save(&n1), "preparing n1")
+ testutils.MustExec(t, db.Save(&n1), "preparing n1")
n2 := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: "n2 content",
@@ -103,8 +107,9 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2018, time.August, 11, 22, 0, 0, 0, time.UTC).UnixNano(),
}
- testutils.MustExec(t, testutils.DB.Save(&n2), "preparing n2")
+ testutils.MustExec(t, db.Save(&n2), "preparing n2")
n3 := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: "n3 content",
@@ -112,8 +117,9 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2017, time.January, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
- testutils.MustExec(t, testutils.DB.Save(&n3), "preparing n3")
+ testutils.MustExec(t, db.Save(&n3), "preparing n3")
n4 := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b2.UUID,
Body: "n4 content",
@@ -121,8 +127,9 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2018, time.September, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
- testutils.MustExec(t, testutils.DB.Save(&n4), "preparing n4")
+ testutils.MustExec(t, db.Save(&n4), "preparing n4")
n5 := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: anotherUser.ID,
BookUUID: b3.UUID,
Body: "n5 content",
@@ -130,8 +137,9 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
- testutils.MustExec(t, testutils.DB.Save(&n5), "preparing n5")
+ testutils.MustExec(t, db.Save(&n5), "preparing n5")
n6 := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: "",
@@ -139,13 +147,13 @@ func TestGetNotes(t *testing.T) {
Deleted: true,
AddedOn: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
- testutils.MustExec(t, testutils.DB.Save(&n6), "preparing n6")
+ testutils.MustExec(t, db.Save(&n6), "preparing n6")
// Execute
endpoint := "/api/v3/notes"
req := testutils.MakeReq(server.URL, "GET", fmt.Sprintf("%s?year=2018&month=8", endpoint), "")
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "")
@@ -156,8 +164,8 @@ func TestGetNotes(t *testing.T) {
}
var n2Record, n1Record database.Note
- testutils.MustExec(t, testutils.DB.Where("uuid = ?", n2.UUID).First(&n2Record), "finding n2Record")
- testutils.MustExec(t, testutils.DB.Where("uuid = ?", n1.UUID).First(&n1Record), "finding n1Record")
+ testutils.MustExec(t, db.Where("uuid = ?", n2.UUID).First(&n2Record), "finding n2Record")
+ testutils.MustExec(t, db.Where("uuid = ?", n1.UUID).First(&n1Record), "finding n1Record")
expected := GetNotesResponse{
Notes: []presenters.Note{
@@ -171,44 +179,48 @@ func TestGetNotes(t *testing.T) {
}
func TestGetNote(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- anotherUser := testutils.SetupUserData()
+ user := testutils.SetupUserData(db)
+ anotherUser := testutils.SetupUserData(db)
b1 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
+ testutils.MustExec(t, db.Save(&b1), "preparing b1")
privateNote := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: "privateNote content",
Public: false,
}
- testutils.MustExec(t, testutils.DB.Save(&privateNote), "preparing privateNote")
+ testutils.MustExec(t, db.Save(&privateNote), "preparing privateNote")
publicNote := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: "publicNote content",
Public: true,
}
- testutils.MustExec(t, testutils.DB.Save(&publicNote), "preparing publicNote")
+ testutils.MustExec(t, db.Save(&publicNote), "preparing publicNote")
deletedNote := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Deleted: true,
}
- testutils.MustExec(t, testutils.DB.Save(&deletedNote), "preparing deletedNote")
+ testutils.MustExec(t, db.Save(&deletedNote), "preparing deletedNote")
getURL := func(noteUUID string) string {
return fmt.Sprintf("/api/v3/notes/%s", noteUUID)
@@ -218,7 +230,7 @@ func TestGetNote(t *testing.T) {
// Execute
url := getURL(publicNote.UUID)
req := testutils.MakeReq(server.URL, "GET", url, "")
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "")
@@ -229,7 +241,7 @@ func TestGetNote(t *testing.T) {
}
var n2Record database.Note
- testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
+ testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
expected := getExpectedNotePayload(n2Record, b1, user)
assert.DeepEqual(t, payload, expected, "payload mismatch")
@@ -239,7 +251,7 @@ func TestGetNote(t *testing.T) {
// Execute
url := getURL(publicNote.UUID)
req := testutils.MakeReq(server.URL, "GET", url, "")
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "")
@@ -250,7 +262,7 @@ func TestGetNote(t *testing.T) {
}
var n2Record database.Note
- testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
+ testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
expected := getExpectedNotePayload(n2Record, b1, user)
assert.DeepEqual(t, payload, expected, "payload mismatch")
@@ -260,7 +272,7 @@ func TestGetNote(t *testing.T) {
// Execute
url := getURL(publicNote.UUID)
req := testutils.MakeReq(server.URL, "GET", url, "")
- res := testutils.HTTPAuthDo(t, req, anotherUser)
+ res := testutils.HTTPAuthDo(t, db, req, anotherUser)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "")
@@ -271,7 +283,7 @@ func TestGetNote(t *testing.T) {
}
var n2Record database.Note
- testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
+ testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
expected := getExpectedNotePayload(n2Record, b1, user)
assert.DeepEqual(t, payload, expected, "payload mismatch")
@@ -281,12 +293,12 @@ func TestGetNote(t *testing.T) {
// Execute
url := getURL(privateNote.UUID)
req := testutils.MakeReq(server.URL, "GET", url, "")
- res := testutils.HTTPAuthDo(t, req, anotherUser)
+ res := testutils.HTTPAuthDo(t, db, req, anotherUser)
// Test
assert.StatusCodeEquals(t, res, http.StatusNotFound, "")
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(errors.Wrap(err, "reading body"))
}
@@ -309,7 +321,7 @@ func TestGetNote(t *testing.T) {
}
var n2Record database.Note
- testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
+ testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
expected := getExpectedNotePayload(n2Record, b1, user)
assert.DeepEqual(t, payload, expected, "payload mismatch")
@@ -324,7 +336,7 @@ func TestGetNote(t *testing.T) {
// Test
assert.StatusCodeEquals(t, res, http.StatusNotFound, "")
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(errors.Wrap(err, "reading body"))
}
@@ -336,12 +348,12 @@ func TestGetNote(t *testing.T) {
// Execute
url := getURL("somerandomstring")
req := testutils.MakeReq(server.URL, "GET", url, "")
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusNotFound, "")
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(errors.Wrap(err, "reading body"))
}
@@ -353,12 +365,12 @@ func TestGetNote(t *testing.T) {
// Execute
url := getURL(deletedNote.UUID)
req := testutils.MakeReq(server.URL, "GET", url, "")
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusNotFound, "")
- body, err := ioutil.ReadAll(res.Body)
+ body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(errors.Wrap(err, "reading body"))
}
@@ -368,31 +380,32 @@ func TestGetNote(t *testing.T) {
}
func TestCreateNote(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@test.com", "pass1234")
- testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
+ user := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
+ testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
USN: 58,
}
- testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
+ testutils.MustExec(t, db.Save(&b1), "preparing b1")
// Execute
dat := fmt.Sprintf(`{"book_uuid": "%s", "content": "note content"}`, b1.UUID)
req := testutils.MakeReq(server.URL, "POST", "/api/v3/notes", dat)
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusCreated, "")
@@ -401,11 +414,11 @@ func TestCreateNote(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
var bookCount, noteCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
- testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
- testutils.MustExec(t, testutils.DB.First(¬eRecord), "finding note")
- testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
- testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
+ testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
+ testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
+ testutils.MustExec(t, db.First(¬eRecord), "finding note")
+ testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
+ testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, int64(1), "book count mismatch")
assert.Equalf(t, noteCount, int64(1), "note count mismatch")
@@ -449,38 +462,39 @@ func TestDeleteNote(t *testing.T) {
for idx, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@test.com", "pass1234")
- testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 981), "preparing user max_usn")
+ user := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
+ testutils.MustExec(t, db.Model(&user).Update("max_usn", 981), "preparing user max_usn")
b1 := database.Book{
UUID: b1UUID,
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
+ testutils.MustExec(t, db.Save(&b1), "preparing b1")
note := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: tc.content,
Deleted: tc.deleted,
USN: tc.originalUSN,
}
- testutils.MustExec(t, testutils.DB.Save(¬e), "preparing note")
+ testutils.MustExec(t, db.Save(¬e), "preparing note")
// Execute
endpoint := fmt.Sprintf("/api/v3/notes/%s", note.UUID)
req := testutils.MakeReq(server.URL, "DELETE", endpoint, "")
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "")
@@ -489,11 +503,11 @@ func TestDeleteNote(t *testing.T) {
var noteRecord database.Note
var userRecord database.User
var bookCount, noteCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
- testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
- testutils.MustExec(t, testutils.DB.Where("uuid = ?", note.UUID).First(¬eRecord), "finding note")
- testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
- testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
+ testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
+ testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
+ testutils.MustExec(t, db.Where("uuid = ?", note.UUID).First(¬eRecord), "finding note")
+ testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
+ testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, int64(1), "book count mismatch")
assert.Equalf(t, noteCount, int64(1), "note count mismatch")
@@ -687,42 +701,42 @@ func TestUpdateNote(t *testing.T) {
for idx, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.DB = db
+ a.Clock = clock.NewMock()
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@test.com", "pass1234")
+ user := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
- testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
+ testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
UUID: b1UUID,
UserID: user.ID,
Label: "css",
}
- testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
+ testutils.MustExec(t, db.Save(&b1), "preparing b1")
b2 := database.Book{
UUID: b2UUID,
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
+ testutils.MustExec(t, db.Save(&b2), "preparing b2")
note := database.Note{
- UserID: user.ID,
UUID: tc.noteUUID,
+ UserID: user.ID,
BookUUID: tc.noteBookUUID,
Body: tc.noteBody,
Deleted: tc.noteDeleted,
Public: tc.notePublic,
}
- testutils.MustExec(t, testutils.DB.Save(¬e), "preparing note")
+ testutils.MustExec(t, db.Save(¬e), "preparing note")
// Execute
var req *http.Request
@@ -730,7 +744,7 @@ func TestUpdateNote(t *testing.T) {
endpoint := fmt.Sprintf("/api/v3/notes/%s", note.UUID)
req = testutils.MakeReq(server.URL, "PATCH", endpoint, tc.payload.ToJSON(t))
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "status code mismatch for test case")
@@ -739,11 +753,11 @@ func TestUpdateNote(t *testing.T) {
var noteRecord database.Note
var userRecord database.User
var noteCount, bookCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
- testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes")
- testutils.MustExec(t, testutils.DB.Where("uuid = ?", note.UUID).First(¬eRecord), "finding note")
- testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
- testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
+ testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
+ testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes")
+ testutils.MustExec(t, db.Where("uuid = ?", note.UUID).First(¬eRecord), "finding note")
+ testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
+ testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, int64(2), "book count mismatch")
assert.Equalf(t, noteCount, int64(1), "note count mismatch")
diff --git a/pkg/server/controllers/routes.go b/pkg/server/controllers/routes.go
index 3e80a5ed..0c873fa3 100644
--- a/pkg/server/controllers/routes.go
+++ b/pkg/server/controllers/routes.go
@@ -48,25 +48,25 @@ func NewWebRoutes(a *app.App, c *Controllers) []Route {
redirectGuest := &mw.AuthParams{RedirectGuestsToLogin: true}
ret := []Route{
- {"GET", "/", mw.Auth(a, c.Users.Settings, redirectGuest), true},
- {"GET", "/about", mw.Auth(a, c.Users.About, redirectGuest), true},
- {"GET", "/login", mw.GuestOnly(a, c.Users.NewLogin), true},
- {"POST", "/login", mw.GuestOnly(a, c.Users.Login), true},
+ {"GET", "/", mw.Auth(a.DB, c.Users.Settings, redirectGuest), true},
+ {"GET", "/about", mw.Auth(a.DB, c.Users.About, redirectGuest), true},
+ {"GET", "/login", mw.GuestOnly(a.DB, c.Users.NewLogin), true},
+ {"POST", "/login", mw.GuestOnly(a.DB, c.Users.Login), true},
{"POST", "/logout", c.Users.Logout, true},
{"GET", "/password-reset", c.Users.PasswordResetView.ServeHTTP, true},
{"PATCH", "/password-reset", c.Users.PasswordReset, true},
{"GET", "/password-reset/{token}", c.Users.PasswordResetConfirm, true},
{"POST", "/reset-token", c.Users.CreateResetToken, true},
- {"POST", "/verification-token", mw.Auth(a, c.Users.CreateEmailVerificationToken, redirectGuest), true},
- {"GET", "/verify-email/{token}", mw.Auth(a, c.Users.VerifyEmail, redirectGuest), true},
- {"PATCH", "/account/profile", mw.Auth(a, c.Users.ProfileUpdate, nil), true},
- {"PATCH", "/account/password", mw.Auth(a, c.Users.PasswordUpdate, nil), true},
+ {"POST", "/verification-token", mw.Auth(a.DB, c.Users.CreateEmailVerificationToken, redirectGuest), true},
+ {"GET", "/verify-email/{token}", mw.Auth(a.DB, c.Users.VerifyEmail, redirectGuest), true},
+ {"PATCH", "/account/profile", mw.Auth(a.DB, c.Users.ProfileUpdate, nil), true},
+ {"PATCH", "/account/password", mw.Auth(a.DB, c.Users.PasswordUpdate, nil), true},
{"GET", "/health", c.Health.Index, true},
}
- if !a.Config.DisableRegistration {
+ if !a.DisableRegistration {
ret = append(ret, Route{"GET", "/join", c.Users.New, true})
ret = append(ret, Route{"POST", "/join", c.Users.Create, true})
}
@@ -76,28 +76,25 @@ func NewWebRoutes(a *app.App, c *Controllers) []Route {
// NewAPIRoutes returns a new api routes
func NewAPIRoutes(a *app.App, c *Controllers) []Route {
-
- proOnly := mw.AuthParams{ProOnly: true}
-
return []Route{
// v3
- {"GET", "/v3/sync/fragment", mw.Cors(mw.Auth(a, c.Sync.GetSyncFragment, &proOnly)), false},
- {"GET", "/v3/sync/state", mw.Cors(mw.Auth(a, c.Sync.GetSyncState, &proOnly)), false},
- {"POST", "/v3/signin", mw.Cors(c.Users.V3Login), true},
- {"POST", "/v3/signout", mw.Cors(c.Users.V3Logout), true},
- {"OPTIONS", "/v3/signout", mw.Cors(c.Users.logoutOptions), true},
- {"GET", "/v3/notes", mw.Cors(mw.Auth(a, c.Notes.V3Index, nil)), true},
+ {"GET", "/v3/sync/fragment", mw.Auth(a.DB, c.Sync.GetSyncFragment, nil), false},
+ {"GET", "/v3/sync/state", mw.Auth(a.DB, c.Sync.GetSyncState, nil), false},
+ {"POST", "/v3/signin", c.Users.V3Login, true},
+ {"POST", "/v3/signout", c.Users.V3Logout, true},
+ {"OPTIONS", "/v3/signout", c.Users.logoutOptions, true},
+ {"GET", "/v3/notes", mw.Auth(a.DB, c.Notes.V3Index, nil), true},
{"GET", "/v3/notes/{noteUUID}", c.Notes.V3Show, true},
- {"POST", "/v3/notes", mw.Cors(mw.Auth(a, c.Notes.V3Create, nil)), true},
- {"DELETE", "/v3/notes/{noteUUID}", mw.Cors(mw.Auth(a, c.Notes.V3Delete, nil)), true},
- {"PATCH", "/v3/notes/{noteUUID}", mw.Cors(mw.Auth(a, c.Notes.V3Update, nil)), true},
- {"OPTIONS", "/v3/notes", mw.Cors(c.Notes.IndexOptions), true},
- {"GET", "/v3/books", mw.Cors(mw.Auth(a, c.Books.V3Index, nil)), true},
- {"GET", "/v3/books/{bookUUID}", mw.Cors(mw.Auth(a, c.Books.V3Show, nil)), true},
- {"POST", "/v3/books", mw.Cors(mw.Auth(a, c.Books.V3Create, nil)), true},
- {"PATCH", "/v3/books/{bookUUID}", mw.Cors(mw.Auth(a, c.Books.V3Update, nil)), true},
- {"DELETE", "/v3/books/{bookUUID}", mw.Cors(mw.Auth(a, c.Books.V3Delete, nil)), true},
- {"OPTIONS", "/v3/books", mw.Cors(c.Books.IndexOptions), true},
+ {"POST", "/v3/notes", mw.Auth(a.DB, c.Notes.V3Create, nil), true},
+ {"DELETE", "/v3/notes/{noteUUID}", mw.Auth(a.DB, c.Notes.V3Delete, nil), true},
+ {"PATCH", "/v3/notes/{noteUUID}", mw.Auth(a.DB, c.Notes.V3Update, nil), true},
+ {"OPTIONS", "/v3/notes", c.Notes.IndexOptions, true},
+ {"GET", "/v3/books", mw.Auth(a.DB, c.Books.V3Index, nil), true},
+ {"GET", "/v3/books/{bookUUID}", mw.Auth(a.DB, c.Books.V3Show, nil), true},
+ {"POST", "/v3/books", mw.Auth(a.DB, c.Books.V3Create, nil), true},
+ {"PATCH", "/v3/books/{bookUUID}", mw.Auth(a.DB, c.Books.V3Update, nil), true},
+ {"DELETE", "/v3/books/{bookUUID}", mw.Auth(a.DB, c.Books.V3Delete, nil), true},
+ {"OPTIONS", "/v3/books", c.Books.IndexOptions, true},
}
}
diff --git a/pkg/server/controllers/routes_test.go b/pkg/server/controllers/routes_test.go
index 39b151cf..d3084aa3 100644
--- a/pkg/server/controllers/routes_test.go
+++ b/pkg/server/controllers/routes_test.go
@@ -25,7 +25,6 @@ import (
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/app"
- "github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/testutils"
)
@@ -56,10 +55,11 @@ func TestNotSupportedVersions(t *testing.T) {
}
// setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ db := testutils.InitMemoryDB(t)
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
for _, tc := range testCases {
diff --git a/pkg/server/controllers/testutils.go b/pkg/server/controllers/testutils.go
index 807f92d8..03a097c8 100644
--- a/pkg/server/controllers/testutils.go
+++ b/pkg/server/controllers/testutils.go
@@ -27,9 +27,9 @@ import (
)
// MustNewServer is a test utility function to initialize a new server
-// with the given app paratmers
-func MustNewServer(t *testing.T, appParams *app.App) *httptest.Server {
- server, err := NewServer(appParams)
+// with the given app
+func MustNewServer(t *testing.T, a *app.App) *httptest.Server {
+ server, err := NewServer(a)
if err != nil {
t.Fatal(errors.Wrap(err, "initializing router"))
}
@@ -37,16 +37,14 @@ func MustNewServer(t *testing.T, appParams *app.App) *httptest.Server {
return server
}
-func NewServer(appParams *app.App) (*httptest.Server, error) {
- a := app.NewTest(appParams)
-
- ctl := New(&a)
+func NewServer(a *app.App) (*httptest.Server, error) {
+ ctl := New(a)
rc := RouteConfig{
- WebRoutes: NewWebRoutes(&a, ctl),
- APIRoutes: NewAPIRoutes(&a, ctl),
+ WebRoutes: NewWebRoutes(a, ctl),
+ APIRoutes: NewAPIRoutes(a, ctl),
Controllers: ctl,
}
- r, err := NewRouter(&a, rc)
+ r, err := NewRouter(a, rc)
if err != nil {
return nil, errors.Wrap(err, "initializing router")
}
diff --git a/pkg/server/controllers/users_test.go b/pkg/server/controllers/users_test.go
index 4546a18a..e4c9ff2d 100644
--- a/pkg/server/controllers/users_test.go
+++ b/pkg/server/controllers/users_test.go
@@ -30,18 +30,18 @@ import (
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/app"
- "github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
+ "gorm.io/gorm"
)
-func assertResponseSessionCookie(t *testing.T, res *http.Response) {
+func assertResponseSessionCookie(t *testing.T, db *gorm.DB, res *http.Response) {
var sessionCount int64
var session database.Session
- testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
- testutils.MustExec(t, testutils.DB.First(&session), "getting session")
+ testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
+ testutils.MustExec(t, db.First(&session), "getting session")
c := testutils.GetCookieByName(res.Cookies(), "id")
assert.Equal(t, c.Value, session.Key, "session key mismatch")
@@ -55,53 +55,35 @@ func TestJoin(t *testing.T) {
email string
password string
passwordConfirmation string
- onPremises bool
- expectedPro bool
}{
{
email: "alice@example.com",
password: "pass1234",
passwordConfirmation: "pass1234",
- onPremises: false,
- expectedPro: false,
},
{
email: "bob@example.com",
password: "Y9EwmjH@Jq6y5a64MSACUoM4w7SAhzvY",
passwordConfirmation: "Y9EwmjH@Jq6y5a64MSACUoM4w7SAhzvY",
- onPremises: false,
- expectedPro: false,
},
{
email: "chuck@example.com",
password: "e*H@kJi^vXbWEcD9T5^Am!Y@7#Po2@PC",
passwordConfirmation: "e*H@kJi^vXbWEcD9T5^Am!Y@7#Po2@PC",
- onPremises: false,
- expectedPro: false,
- },
- // on premise
- {
- email: "dan@example.com",
- password: "e*H@kJi^vXbWEcD9T5^Am!Y@7#Po2@PC",
- passwordConfirmation: "e*H@kJi^vXbWEcD9T5^Am!Y@7#Po2@PC",
- onPremises: true,
- expectedPro: true,
},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("register %s %s", tc.email, tc.password), func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
emailBackend := testutils.MockEmailbackendImplementation{}
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- EmailBackend: &emailBackend,
- Config: config.Config{
- OnPremises: tc.onPremises,
- },
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.EmailBackend = &emailBackend
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
dat := url.Values{}
@@ -117,15 +99,14 @@ func TestJoin(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusFound, "")
var account database.Account
- testutils.MustExec(t, testutils.DB.Where("email = ?", tc.email).First(&account), "finding account")
+ testutils.MustExec(t, db.Where("email = ?", tc.email).First(&account), "finding account")
assert.Equal(t, account.Email.String, tc.email, "Email mismatch")
assert.NotEqual(t, account.UserID, 0, "UserID mismatch")
passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte(tc.password))
assert.Equal(t, passwordErr, nil, "Password mismatch")
var user database.User
- testutils.MustExec(t, testutils.DB.Where("id = ?", account.UserID).First(&user), "finding user")
- assert.Equal(t, user.Cloud, tc.expectedPro, "Cloud mismatch")
+ testutils.MustExec(t, db.Where("id = ?", account.UserID).First(&user), "finding user")
assert.Equal(t, user.MaxUSN, 0, "MaxUSN mismatch")
// welcome email
@@ -133,20 +114,20 @@ func TestJoin(t *testing.T) {
assert.DeepEqual(t, emailBackend.Emails[0].To, []string{tc.email}, "email to mismatch")
// after register, should sign in user
- assertResponseSessionCookie(t, res)
+ assertResponseSessionCookie(t, db, res)
})
}
}
func TestJoinError(t *testing.T) {
t.Run("missing email", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
dat := url.Values{}
@@ -160,21 +141,21 @@ func TestJoinError(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch")
var accountCount, userCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account")
- testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
+ testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
+ testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
assert.Equal(t, accountCount, int64(0), "accountCount mismatch")
assert.Equal(t, userCount, int64(0), "userCount mismatch")
})
t.Run("missing password", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
dat := url.Values{}
@@ -188,21 +169,21 @@ func TestJoinError(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch")
var accountCount, userCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account")
- testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
+ testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
+ testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
assert.Equal(t, accountCount, int64(0), "accountCount mismatch")
assert.Equal(t, userCount, int64(0), "userCount mismatch")
})
t.Run("password confirmation mismatch", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
dat := url.Values{}
@@ -218,8 +199,8 @@ func TestJoinError(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch")
var accountCount, userCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account")
- testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
+ testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
+ testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
assert.Equal(t, accountCount, int64(0), "accountCount mismatch")
assert.Equal(t, userCount, int64(0), "userCount mismatch")
@@ -227,17 +208,17 @@ func TestJoinError(t *testing.T) {
}
func TestJoinDuplicateEmail(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- u := testutils.SetupUserData()
- testutils.SetupAccountData(u, "alice@example.com", "somepassword")
+ u := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
dat := url.Values{}
dat.Set("email", "alice@example.com")
@@ -252,12 +233,12 @@ func TestJoinDuplicateEmail(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "status code mismatch")
var accountCount, userCount, verificationTokenCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account")
- testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
- testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&verificationTokenCount), "counting verification token")
+ testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
+ testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
+ testutils.MustExec(t, db.Model(&database.Token{}).Count(&verificationTokenCount), "counting verification token")
var user database.User
- testutils.MustExec(t, testutils.DB.Where("id = ?", u.ID).First(&user), "finding user")
+ testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user")
assert.Equal(t, accountCount, int64(1), "account count mismatch")
assert.Equal(t, userCount, int64(1), "user count mismatch")
@@ -266,15 +247,14 @@ func TestJoinDuplicateEmail(t *testing.T) {
}
func TestJoinDisabled(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{
- DisableRegistration: true,
- },
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ a.DisableRegistration = true
+ server := MustNewServer(t, &a)
defer server.Close()
dat := url.Values{}
@@ -289,8 +269,8 @@ func TestJoinDisabled(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusNotFound, "status code mismatch")
var accountCount, userCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account")
- testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
+ testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
+ testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
assert.Equal(t, accountCount, int64(0), "account count mismatch")
assert.Equal(t, userCount, int64(0), "user count mismatch")
@@ -298,16 +278,16 @@ func TestJoinDisabled(t *testing.T) {
func TestLogin(t *testing.T) {
testutils.RunForWebAndAPI(t, "success", func(t *testing.T, target testutils.EndpointType) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
- u := testutils.SetupUserData()
- testutils.SetupAccountData(u, "alice@example.com", "pass1234")
+ u := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, u, "alice@example.com", "pass1234")
defer server.Close()
// Execute
@@ -332,11 +312,11 @@ func TestLogin(t *testing.T) {
}
var user database.User
- testutils.MustExec(t, testutils.DB.Model(&database.User{}).First(&user), "finding user")
+ testutils.MustExec(t, db.Model(&database.User{}).First(&user), "finding user")
assert.NotEqual(t, user.LastLoginAt, nil, "LastLoginAt mismatch")
if target == testutils.EndpointWeb {
- assertResponseSessionCookie(t, res)
+ assertResponseSessionCookie(t, db, res)
} else {
// after register, should sign in user
var got SessionResponse
@@ -346,28 +326,28 @@ func TestLogin(t *testing.T) {
var sessionCount int64
var session database.Session
- testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
- testutils.MustExec(t, testutils.DB.First(&session), "getting session")
+ testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
+ testutils.MustExec(t, db.First(&session), "getting session")
assert.Equal(t, sessionCount, int64(1), "sessionCount mismatch")
assert.Equal(t, got.Key, session.Key, "session Key mismatch")
assert.Equal(t, got.ExpiresAt, session.ExpiresAt.Unix(), "session ExpiresAt mismatch")
- assertResponseSessionCookie(t, res)
+ assertResponseSessionCookie(t, db, res)
}
})
testutils.RunForWebAndAPI(t, "wrong password", func(t *testing.T, target testutils.EndpointType) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
- u := testutils.SetupUserData()
- testutils.SetupAccountData(u, "alice@example.com", "pass1234")
+ u := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, u, "alice@example.com", "pass1234")
defer server.Close()
var req *http.Request
@@ -388,26 +368,26 @@ func TestLogin(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
var user database.User
- testutils.MustExec(t, testutils.DB.Model(&database.User{}).First(&user), "finding user")
+ testutils.MustExec(t, db.Model(&database.User{}).First(&user), "finding user")
assert.Equal(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch")
var sessionCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
+ testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
assert.Equal(t, sessionCount, int64(0), "sessionCount mismatch")
})
testutils.RunForWebAndAPI(t, "wrong email", func(t *testing.T, target testutils.EndpointType) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- u := testutils.SetupUserData()
- testutils.SetupAccountData(u, "alice@example.com", "pass1234")
+ u := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, u, "alice@example.com", "pass1234")
var req *http.Request
if target == testutils.EndpointWeb {
@@ -427,22 +407,22 @@ func TestLogin(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
var user database.User
- testutils.MustExec(t, testutils.DB.Model(&database.User{}).First(&user), "finding user")
+ testutils.MustExec(t, db.Model(&database.User{}).First(&user), "finding user")
assert.DeepEqual(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch")
var sessionCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
+ testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
assert.Equal(t, sessionCount, int64(0), "sessionCount mismatch")
})
testutils.RunForWebAndAPI(t, "nonexistent email", func(t *testing.T, target testutils.EndpointType) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
var req *http.Request
@@ -463,22 +443,22 @@ func TestLogin(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "")
var sessionCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
+ testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
assert.Equal(t, sessionCount, int64(0), "sessionCount mismatch")
})
}
func TestLogout(t *testing.T) {
- setupLogoutTest := func(t *testing.T) (*httptest.Server, *database.Session, *database.Session) {
+ setupLogoutTest := func(t *testing.T, db *gorm.DB) (*httptest.Server, *database.Session, *database.Session) {
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
- aliceUser := testutils.SetupUserData()
- testutils.SetupAccountData(aliceUser, "alice@example.com", "pass1234")
- anotherUser := testutils.SetupUserData()
+ aliceUser := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, aliceUser, "alice@example.com", "pass1234")
+ anotherUser := testutils.SetupUserData(db)
session1ExpiresAt := time.Now().Add(time.Hour * 24)
session1 := database.Session{
@@ -486,21 +466,21 @@ func TestLogout(t *testing.T) {
UserID: aliceUser.ID,
ExpiresAt: session1ExpiresAt,
}
- testutils.MustExec(t, testutils.DB.Save(&session1), "preparing session1")
+ testutils.MustExec(t, db.Save(&session1), "preparing session1")
session2 := database.Session{
Key: "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=",
UserID: anotherUser.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
- testutils.MustExec(t, testutils.DB.Save(&session2), "preparing session2")
+ testutils.MustExec(t, db.Save(&session2), "preparing session2")
return server, &session1, &session2
}
testutils.RunForWebAndAPI(t, "authenticated", func(t *testing.T, target testutils.EndpointType) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
- server, session1, _ := setupLogoutTest(t)
+ server, session1, _ := setupLogoutTest(t, db)
defer server.Close()
// Execute
@@ -525,8 +505,8 @@ func TestLogout(t *testing.T) {
var sessionCount int64
var s2 database.Session
- testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
- testutils.MustExec(t, testutils.DB.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&s2), "getting s2")
+ testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
+ testutils.MustExec(t, db.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&s2), "getting s2")
assert.Equal(t, sessionCount, int64(1), "sessionCount mismatch")
@@ -542,9 +522,9 @@ func TestLogout(t *testing.T) {
})
testutils.RunForWebAndAPI(t, "unauthenticated", func(t *testing.T, target testutils.EndpointType) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
- server, _, _ := setupLogoutTest(t)
+ server, _, _ := setupLogoutTest(t, db)
defer server.Close()
// Execute
@@ -567,9 +547,9 @@ func TestLogout(t *testing.T) {
var sessionCount int64
var postSession1, postSession2 database.Session
- testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session")
- testutils.MustExec(t, testutils.DB.Where("key = ?", "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=").First(&postSession1), "getting postSession1")
- testutils.MustExec(t, testutils.DB.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&postSession2), "getting postSession2")
+ testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session")
+ testutils.MustExec(t, db.Where("key = ?", "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=").First(&postSession1), "getting postSession1")
+ testutils.MustExec(t, db.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&postSession2), "getting postSession2")
// two existing sessions should remain
assert.Equal(t, sessionCount, int64(2), "sessionCount mismatch")
@@ -581,46 +561,46 @@ func TestLogout(t *testing.T) {
func TestResetPassword(t *testing.T) {
t.Run("success", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- u := testutils.SetupUserData()
- a := testutils.SetupAccountData(u, "alice@example.com", "oldpassword")
+ u := testutils.SetupUserData(db)
+ acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword")
tok := database.Token{
UserID: u.ID,
Value: "MivFxYiSMMA4An9dP24DNQ==",
Type: database.TokenTypeResetPassword,
}
- testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
+ testutils.MustExec(t, db.Save(&tok), "preparing token")
otherTok := database.Token{
UserID: u.ID,
Value: "somerandomvalue",
Type: database.TokenTypeEmailVerification,
}
- testutils.MustExec(t, testutils.DB.Save(&otherTok), "preparing another token")
+ testutils.MustExec(t, db.Save(&otherTok), "preparing another token")
s1 := database.Session{
Key: "some-session-key-1",
UserID: u.ID,
ExpiresAt: time.Now().Add(time.Hour * 10 * 24),
}
- testutils.MustExec(t, testutils.DB.Save(&s1), "preparing user session 1")
+ testutils.MustExec(t, db.Save(&s1), "preparing user session 1")
s2 := &database.Session{
Key: "some-session-key-2",
UserID: u.ID,
ExpiresAt: time.Now().Add(time.Hour * 10 * 24),
}
- testutils.MustExec(t, testutils.DB.Save(&s2), "preparing user session 2")
+ testutils.MustExec(t, db.Save(&s2), "preparing user session 2")
- anotherUser := testutils.SetupUserData()
- testutils.MustExec(t, testutils.DB.Save(&database.Session{
+ anotherUser := testutils.SetupUserData(db)
+ testutils.MustExec(t, db.Save(&database.Session{
Key: "some-session-key-3",
UserID: anotherUser.ID,
ExpiresAt: time.Now().Add(time.Hour * 10 * 24),
@@ -640,9 +620,9 @@ func TestResetPassword(t *testing.T) {
var resetToken, verificationToken database.Token
var account database.Account
- testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
- testutils.MustExec(t, testutils.DB.Where("value = ?", "somerandomvalue").First(&verificationToken), "finding reset token")
- testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "finding account")
+ testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
+ testutils.MustExec(t, db.Where("value = ?", "somerandomvalue").First(&verificationToken), "finding reset token")
+ testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "finding account")
assert.NotEqual(t, resetToken.UsedAt, nil, "reset_token UsedAt mismatch")
passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte("newpassword"))
@@ -650,38 +630,38 @@ func TestResetPassword(t *testing.T) {
assert.Equal(t, verificationToken.UsedAt, (*time.Time)(nil), "verificationToken UsedAt mismatch")
var s1Count, s2Count int64
- testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Where("id = ?", s1.ID).Count(&s1Count), "counting s1")
- testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Where("id = ?", s2.ID).Count(&s2Count), "counting s2")
+ testutils.MustExec(t, db.Model(&database.Session{}).Where("id = ?", s1.ID).Count(&s1Count), "counting s1")
+ testutils.MustExec(t, db.Model(&database.Session{}).Where("id = ?", s2.ID).Count(&s2Count), "counting s2")
assert.Equal(t, s1Count, int64(0), "s1 should have been deleted")
assert.Equal(t, s2Count, int64(0), "s2 should have been deleted")
var userSessionCount, anotherUserSessionCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Where("user_id = ?", u.ID).Count(&userSessionCount), "counting user session")
- testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Where("user_id = ?", anotherUser.ID).Count(&anotherUserSessionCount), "counting anotherUser session")
+ testutils.MustExec(t, db.Model(&database.Session{}).Where("user_id = ?", u.ID).Count(&userSessionCount), "counting user session")
+ testutils.MustExec(t, db.Model(&database.Session{}).Where("user_id = ?", anotherUser.ID).Count(&anotherUserSessionCount), "counting anotherUser session")
assert.Equal(t, userSessionCount, int64(0), "should have deleted a user session")
assert.Equal(t, anotherUserSessionCount, int64(1), "anotherUser session count mismatch")
})
t.Run("nonexistent token", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- u := testutils.SetupUserData()
- a := testutils.SetupAccountData(u, "alice@example.com", "somepassword")
+ u := testutils.SetupUserData(db)
+ acc := testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
tok := database.Token{
UserID: u.ID,
Value: "MivFxYiSMMA4An9dP24DNQ==",
Type: database.TokenTypeResetPassword,
}
- testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
+ testutils.MustExec(t, db.Save(&tok), "preparing token")
dat := url.Values{}
dat.Set("token", "-ApMnyvpg59uOU5b-Kf5uQ==")
@@ -697,33 +677,33 @@ func TestResetPassword(t *testing.T) {
var resetToken database.Token
var account database.Account
- testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
- testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "finding account")
+ testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token")
+ testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "finding account")
- assert.Equal(t, a.Password, account.Password, "password should not have been updated")
- assert.Equal(t, a.Password, account.Password, "password should not have been updated")
+ assert.Equal(t, acc.Password, account.Password, "password should not have been updated")
+ assert.Equal(t, acc.Password, account.Password, "password should not have been updated")
assert.Equal(t, resetToken.UsedAt, (*time.Time)(nil), "used_at should be nil")
})
t.Run("expired token", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- u := testutils.SetupUserData()
- a := testutils.SetupAccountData(u, "alice@example.com", "somepassword")
+ u := testutils.SetupUserData(db)
+ acc := testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
tok := database.Token{
UserID: u.ID,
Value: "MivFxYiSMMA4An9dP24DNQ==",
Type: database.TokenTypeResetPassword,
}
- testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
- testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
+ testutils.MustExec(t, db.Save(&tok), "preparing token")
+ testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
dat := url.Values{}
dat.Set("token", "MivFxYiSMMA4An9dP24DNQ==")
@@ -739,24 +719,24 @@ func TestResetPassword(t *testing.T) {
var resetToken database.Token
var account database.Account
- testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
- testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "failed to find account")
- assert.Equal(t, a.Password, account.Password, "password should not have been updated")
+ testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
+ testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "failed to find account")
+ assert.Equal(t, acc.Password, account.Password, "password should not have been updated")
assert.Equal(t, resetToken.UsedAt, (*time.Time)(nil), "used_at should be nil")
})
t.Run("used token", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- u := testutils.SetupUserData()
- a := testutils.SetupAccountData(u, "alice@example.com", "somepassword")
+ u := testutils.SetupUserData(db)
+ acc := testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
usedAt := time.Now().Add(time.Hour * -11).UTC()
tok := database.Token{
@@ -765,8 +745,8 @@ func TestResetPassword(t *testing.T) {
Type: database.TokenTypeResetPassword,
UsedAt: &usedAt,
}
- testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
- testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
+ testutils.MustExec(t, db.Save(&tok), "preparing token")
+ testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
dat := url.Values{}
dat.Set("token", "MivFxYiSMMA4An9dP24DNQ==")
@@ -782,9 +762,9 @@ func TestResetPassword(t *testing.T) {
var resetToken database.Token
var account database.Account
- testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
- testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "failed to find account")
- assert.Equal(t, a.Password, account.Password, "password should not have been updated")
+ testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
+ testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "failed to find account")
+ assert.Equal(t, acc.Password, account.Password, "password should not have been updated")
resetTokenUsedAtUTC := resetToken.UsedAt.UTC()
if resetTokenUsedAtUTC.Year() != usedAt.Year() ||
@@ -798,24 +778,24 @@ func TestResetPassword(t *testing.T) {
})
t.Run("using wrong type token: email_verification", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- u := testutils.SetupUserData()
- a := testutils.SetupAccountData(u, "alice@example.com", "somepassword")
+ u := testutils.SetupUserData(db)
+ acc := testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
tok := database.Token{
UserID: u.ID,
Value: "MivFxYiSMMA4An9dP24DNQ==",
Type: database.TokenTypeEmailVerification,
}
- testutils.MustExec(t, testutils.DB.Save(&tok), "Failed to prepare reset_token")
- testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
+ testutils.MustExec(t, db.Save(&tok), "Failed to prepare reset_token")
+ testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at")
dat := url.Values{}
dat.Set("token", "MivFxYiSMMA4An9dP24DNQ==")
@@ -831,27 +811,27 @@ func TestResetPassword(t *testing.T) {
var resetToken database.Token
var account database.Account
- testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
- testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "failed to find account")
+ testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token")
+ testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "failed to find account")
- assert.Equal(t, a.Password, account.Password, "password should not have been updated")
+ assert.Equal(t, acc.Password, account.Password, "password should not have been updated")
assert.Equal(t, resetToken.UsedAt, (*time.Time)(nil), "used_at should be nil")
})
}
func TestCreateResetToken(t *testing.T) {
t.Run("success", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- u := testutils.SetupUserData()
- testutils.SetupAccountData(u, "alice@example.com", "somepassword")
+ u := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
// Execute
dat := url.Values{}
@@ -864,10 +844,10 @@ func TestCreateResetToken(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusFound, "Status code mismtach")
var tokenCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting tokens")
+ testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting tokens")
var resetToken database.Token
- testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", u.ID, database.TokenTypeResetPassword).First(&resetToken), "finding reset token")
+ testutils.MustExec(t, db.Where("user_id = ? AND type = ?", u.ID, database.TokenTypeResetPassword).First(&resetToken), "finding reset token")
assert.Equal(t, tokenCount, int64(1), "reset_token count mismatch")
assert.NotEqual(t, resetToken.Value, nil, "reset_token value mismatch")
@@ -875,17 +855,17 @@ func TestCreateResetToken(t *testing.T) {
})
t.Run("nonexistent email", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- u := testutils.SetupUserData()
- testutils.SetupAccountData(u, "alice@example.com", "somepassword")
+ u := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, u, "alice@example.com", "somepassword")
// Execute
dat := url.Values{}
@@ -898,24 +878,24 @@ func TestCreateResetToken(t *testing.T) {
assert.StatusCodeEquals(t, res, http.StatusOK, "Status code mismtach")
var tokenCount int64
- testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting tokens")
+ testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting tokens")
assert.Equal(t, tokenCount, int64(0), "reset_token count mismatch")
})
}
func TestUpdatePassword(t *testing.T) {
t.Run("success", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@example.com", "oldpassword")
+ user := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword")
// Execute
dat := url.Values{}
@@ -924,29 +904,29 @@ func TestUpdatePassword(t *testing.T) {
dat.Set("new_password_confirmation", "newpassword")
req := testutils.MakeFormReq(server.URL, "PATCH", "/account/password", dat)
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusFound, "Status code mismsatch")
var account database.Account
- testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
+ testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte("newpassword"))
assert.Equal(t, passwordErr, nil, "Password mismatch")
})
t.Run("old password mismatch", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- u := testutils.SetupUserData()
- a := testutils.SetupAccountData(u, "alice@example.com", "oldpassword")
+ u := testutils.SetupUserData(db)
+ acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword")
// Execute
dat := url.Values{}
@@ -955,28 +935,28 @@ func TestUpdatePassword(t *testing.T) {
dat.Set("new_password_confirmation", "newpassword")
req := testutils.MakeFormReq(server.URL, "PATCH", "/account/password", dat)
- res := testutils.HTTPAuthDo(t, req, u)
+ res := testutils.HTTPAuthDo(t, db, req, u)
// Test
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "Status code mismsatch")
var account database.Account
- testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account")
- assert.Equal(t, a.Password.String, account.Password.String, "password should not have been updated")
+ testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
+ assert.Equal(t, acc.Password.String, account.Password.String, "password should not have been updated")
})
t.Run("password too short", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- u := testutils.SetupUserData()
- a := testutils.SetupAccountData(u, "alice@example.com", "oldpassword")
+ u := testutils.SetupUserData(db)
+ acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword")
// Execute
dat := url.Values{}
@@ -985,28 +965,28 @@ func TestUpdatePassword(t *testing.T) {
dat.Set("new_password_confirmation", "a")
req := testutils.MakeFormReq(server.URL, "PATCH", "/account/password", dat)
- res := testutils.HTTPAuthDo(t, req, u)
+ res := testutils.HTTPAuthDo(t, db, req, u)
// Test
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismsatch")
var account database.Account
- testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account")
- assert.Equal(t, a.Password.String, account.Password.String, "password should not have been updated")
+ testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
+ assert.Equal(t, acc.Password.String, account.Password.String, "password should not have been updated")
})
t.Run("password confirmation mismatch", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- u := testutils.SetupUserData()
- a := testutils.SetupAccountData(u, "alice@example.com", "oldpassword")
+ u := testutils.SetupUserData(db)
+ acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword")
// Execute
dat := url.Values{}
@@ -1015,32 +995,32 @@ func TestUpdatePassword(t *testing.T) {
dat.Set("new_password_confirmation", "newpassword2")
req := testutils.MakeFormReq(server.URL, "PATCH", "/account/password", dat)
- res := testutils.HTTPAuthDo(t, req, u)
+ res := testutils.HTTPAuthDo(t, db, req, u)
// Test
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismsatch")
var account database.Account
- testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account")
- assert.Equal(t, a.Password.String, account.Password.String, "password should not have been updated")
+ testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
+ assert.Equal(t, acc.Password.String, account.Password.String, "password should not have been updated")
})
}
func TestUpdateEmail(t *testing.T) {
t.Run("success", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- u := testutils.SetupUserData()
- a := testutils.SetupAccountData(u, "alice@example.com", "pass1234")
- a.EmailVerified = true
- testutils.MustExec(t, testutils.DB.Save(&a), "updating email_verified")
+ u := testutils.SetupUserData(db)
+ acc := testutils.SetupAccountData(db, u, "alice@example.com", "pass1234")
+ acc.EmailVerified = true
+ testutils.MustExec(t, db.Save(&acc), "updating email_verified")
// Execute
dat := url.Values{}
@@ -1048,34 +1028,34 @@ func TestUpdateEmail(t *testing.T) {
dat.Set("password", "pass1234")
req := testutils.MakeFormReq(server.URL, "PATCH", "/account/profile", dat)
- res := testutils.HTTPAuthDo(t, req, u)
+ res := testutils.HTTPAuthDo(t, db, req, u)
// Test
assert.StatusCodeEquals(t, res, http.StatusFound, "Status code mismatch")
var user database.User
var account database.Account
- testutils.MustExec(t, testutils.DB.Where("id = ?", u.ID).First(&user), "finding user")
- testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account")
+ testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user")
+ testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
assert.Equal(t, account.Email.String, "alice-new@example.com", "email mismatch")
assert.Equal(t, account.EmailVerified, false, "EmailVerified mismatch")
})
t.Run("password mismatch", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- u := testutils.SetupUserData()
- a := testutils.SetupAccountData(u, "alice@example.com", "pass1234")
- a.EmailVerified = true
- testutils.MustExec(t, testutils.DB.Save(&a), "updating email_verified")
+ u := testutils.SetupUserData(db)
+ acc := testutils.SetupAccountData(db, u, "alice@example.com", "pass1234")
+ acc.EmailVerified = true
+ testutils.MustExec(t, db.Save(&acc), "updating email_verified")
// Execute
dat := url.Values{}
@@ -1083,15 +1063,15 @@ func TestUpdateEmail(t *testing.T) {
dat.Set("password", "wrongpassword")
req := testutils.MakeFormReq(server.URL, "PATCH", "/account/profile", dat)
- res := testutils.HTTPAuthDo(t, req, u)
+ res := testutils.HTTPAuthDo(t, db, req, u)
// Test
assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "Status code mismsatch")
var user database.User
var account database.Account
- testutils.MustExec(t, testutils.DB.Where("id = ?", u.ID).First(&user), "finding user")
- testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account")
+ testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user")
+ testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account")
assert.Equal(t, account.Email.String, "alice@example.com", "email mismatch")
assert.Equal(t, account.EmailVerified, true, "EmailVerified mismatch")
@@ -1100,27 +1080,27 @@ func TestUpdateEmail(t *testing.T) {
func TestVerifyEmail(t *testing.T) {
t.Run("success", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@example.com", "pass1234")
+ user := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, user, "alice@example.com", "pass1234")
tok := database.Token{
UserID: user.ID,
Type: database.TokenTypeEmailVerification,
Value: "someTokenValue",
}
- testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
+ testutils.MustExec(t, db.Save(&tok), "preparing token")
// Execute
req := testutils.MakeReq(server.URL, "GET", fmt.Sprintf("/verify-email/%s", "someTokenValue"), "")
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusFound, "Status code mismatch")
@@ -1128,9 +1108,9 @@ func TestVerifyEmail(t *testing.T) {
var account database.Account
var token database.Token
var tokenCount int64
- testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
- testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
- testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
+ testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
+ testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
+ testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
assert.Equal(t, account.EmailVerified, true, "email_verified mismatch")
assert.NotEqual(t, token.Value, "", "token value should not have been updated")
@@ -1139,17 +1119,17 @@ func TestVerifyEmail(t *testing.T) {
})
t.Run("used token", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@example.com", "pass1234")
+ user := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, user, "alice@example.com", "pass1234")
usedAt := time.Now().Add(time.Hour * -11).UTC()
tok := database.Token{
@@ -1158,11 +1138,11 @@ func TestVerifyEmail(t *testing.T) {
Value: "someTokenValue",
UsedAt: &usedAt,
}
- testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
+ testutils.MustExec(t, db.Save(&tok), "preparing token")
// Execute
req := testutils.MakeReq(server.URL, "GET", fmt.Sprintf("/verify-email/%s", "someTokenValue"), "")
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusBadRequest, "")
@@ -1170,9 +1150,9 @@ func TestVerifyEmail(t *testing.T) {
var account database.Account
var token database.Token
var tokenCount int64
- testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
- testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
- testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
+ testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
+ testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
+ testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
assert.Equal(t, account.EmailVerified, false, "email_verified mismatch")
assert.NotEqual(t, token.UsedAt, nil, "token used_at mismatch")
@@ -1181,29 +1161,29 @@ func TestVerifyEmail(t *testing.T) {
})
t.Run("expired token", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@example.com", "pass1234")
+ user := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, user, "alice@example.com", "pass1234")
tok := database.Token{
UserID: user.ID,
Type: database.TokenTypeEmailVerification,
Value: "someTokenValue",
}
- testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
- testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-31)), "Failed to prepare token created_at")
+ testutils.MustExec(t, db.Save(&tok), "preparing token")
+ testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-31)), "Failed to prepare token created_at")
// Execute
req := testutils.MakeReq(server.URL, "GET", fmt.Sprintf("/verify-email/%s", "someTokenValue"), "")
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusGone, "")
@@ -1211,9 +1191,9 @@ func TestVerifyEmail(t *testing.T) {
var account database.Account
var token database.Token
var tokenCount int64
- testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
- testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
- testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
+ testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
+ testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
+ testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
assert.Equal(t, account.EmailVerified, false, "email_verified mismatch")
assert.Equal(t, tokenCount, int64(1), "token count mismatch")
@@ -1221,30 +1201,30 @@ func TestVerifyEmail(t *testing.T) {
})
t.Run("already verified", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- a := testutils.SetupAccountData(user, "alice@example.com", "oldpass1234")
- a.EmailVerified = true
- testutils.MustExec(t, testutils.DB.Save(&a), "preparing account")
+ user := testutils.SetupUserData(db)
+ acc := testutils.SetupAccountData(db, user, "alice@example.com", "oldpass1234")
+ acc.EmailVerified = true
+ testutils.MustExec(t, db.Save(&acc), "preparing account")
tok := database.Token{
UserID: user.ID,
Type: database.TokenTypeEmailVerification,
Value: "someTokenValue",
}
- testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
+ testutils.MustExec(t, db.Save(&tok), "preparing token")
// Execute
req := testutils.MakeReq(server.URL, "GET", fmt.Sprintf("/verify-email/%s", "someTokenValue"), "")
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusConflict, "")
@@ -1252,9 +1232,9 @@ func TestVerifyEmail(t *testing.T) {
var account database.Account
var token database.Token
var tokenCount int64
- testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
- testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
- testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
+ testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
+ testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
+ testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
assert.Equal(t, account.EmailVerified, true, "email_verified mismatch")
assert.Equal(t, tokenCount, int64(1), "token count mismatch")
@@ -1264,23 +1244,23 @@ func TestVerifyEmail(t *testing.T) {
func TestCreateVerificationToken(t *testing.T) {
t.Run("success", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
emailBackend := testutils.MockEmailbackendImplementation{}
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- EmailBackend: &emailBackend,
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ a.EmailBackend = &emailBackend
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@example.com", "pass1234")
+ user := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, user, "alice@example.com", "pass1234")
// Execute
req := testutils.MakeReq(server.URL, "POST", "/verification-token", "")
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusFound, "status code mismatch")
@@ -1288,9 +1268,9 @@ func TestCreateVerificationToken(t *testing.T) {
var account database.Account
var token database.Token
var tokenCount int64
- testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
- testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
- testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
+ testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
+ testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token")
+ testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
assert.Equal(t, account.EmailVerified, false, "email_verified should not have been updated")
assert.NotEqual(t, token.Value, "", "token Value mismatch")
@@ -1300,30 +1280,30 @@ func TestCreateVerificationToken(t *testing.T) {
})
t.Run("already verified", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Setup
- server := MustNewServer(t, &app.App{
- Clock: clock.NewMock(),
- Config: config.Config{},
- })
+ a := app.NewTest()
+ a.Clock = clock.NewMock()
+ a.DB = db
+ server := MustNewServer(t, &a)
defer server.Close()
- user := testutils.SetupUserData()
- a := testutils.SetupAccountData(user, "alice@example.com", "pass1234")
- a.EmailVerified = true
- testutils.MustExec(t, testutils.DB.Save(&a), "preparing account")
+ user := testutils.SetupUserData(db)
+ acc := testutils.SetupAccountData(db, user, "alice@example.com", "pass1234")
+ acc.EmailVerified = true
+ testutils.MustExec(t, db.Save(&acc), "preparing account")
// Execute
req := testutils.MakeReq(server.URL, "POST", "/verification-token", "")
- res := testutils.HTTPAuthDo(t, req, user)
+ res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusConflict, "Status code mismatch")
var account database.Account
var tokenCount int64
- testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account")
- testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token")
+ testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account")
+ testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token")
assert.Equal(t, account.EmailVerified, true, "email_verified should not have been updated")
assert.Equal(t, tokenCount, int64(0), "token count mismatch")
diff --git a/pkg/server/database/consts.go b/pkg/server/database/consts.go
index 38c4d536..b4a1db03 100644
--- a/pkg/server/database/consts.go
+++ b/pkg/server/database/consts.go
@@ -23,8 +23,6 @@ const (
TokenTypeResetPassword = "reset_password"
// TokenTypeEmailVerification is a type of a token for verifying email
TokenTypeEmailVerification = "email_verification"
- // TokenTypeEmailPreference is a type of a token for updating email preference
- TokenTypeEmailPreference = "email_preference"
)
const (
diff --git a/pkg/server/database/database.go b/pkg/server/database/database.go
index 3c5b6d9b..eaab7c50 100644
--- a/pkg/server/database/database.go
+++ b/pkg/server/database/database.go
@@ -19,9 +19,11 @@
package database
import (
- "github.com/dnote/dnote/pkg/server/config"
+ "os"
+ "path/filepath"
+
"github.com/pkg/errors"
- "gorm.io/driver/postgres"
+ "gorm.io/driver/sqlite"
"gorm.io/gorm"
)
@@ -32,18 +34,12 @@ var (
// InitSchema migrates database schema to reflect the latest model definition
func InitSchema(db *gorm.DB) {
- if err := db.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`).Error; err != nil {
- panic(err)
- }
-
if err := db.AutoMigrate(
&User{},
&Account{},
&Book{},
&Note{},
- &Notification{},
&Token{},
- &EmailPreference{},
&Session{},
); err != nil {
panic(err)
@@ -51,8 +47,14 @@ func InitSchema(db *gorm.DB) {
}
// Open initializes the database connection
-func Open(c config.Config) *gorm.DB {
- db, err := gorm.Open(postgres.Open(c.DB.GetConnectionStr()), &gorm.Config{})
+func Open(dbPath string) *gorm.DB {
+ // Create directory if it doesn't exist
+ dir := filepath.Dir(dbPath)
+ if err := os.MkdirAll(dir, 0755); err != nil {
+ panic(errors.Wrapf(err, "creating database directory at %s", dir))
+ }
+
+ db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
if err != nil {
panic(errors.Wrap(err, "opening database conection"))
}
diff --git a/pkg/server/database/migrate.go b/pkg/server/database/migrate.go
index 8b258b12..4250fd80 100644
--- a/pkg/server/database/migrate.go
+++ b/pkg/server/database/migrate.go
@@ -19,34 +19,167 @@
package database
import (
- "log"
- "net/http"
+ "fmt"
+ "io/fs"
+ "sort"
+ "strings"
"github.com/dnote/dnote/pkg/server/database/migrations"
- "gorm.io/gorm"
+ "github.com/dnote/dnote/pkg/server/log"
"github.com/pkg/errors"
- "github.com/rubenv/sql-migrate"
+ "gorm.io/gorm"
)
-// Migrate runs the migrations
-func Migrate(db *gorm.DB) error {
- migrations := &migrate.HttpFileSystemMigrationSource{
- FileSystem: http.FileSystem(http.FS(migrations.Files)),
+type migrationFile struct {
+ filename string
+ version int
+}
+
+// validateMigrationFilename checks if filename follows format: NNN-description.sql
+func validateMigrationFilename(name string) error {
+ // Check .sql extension
+ if !strings.HasSuffix(name, ".sql") {
+ return errors.Errorf("invalid migration filename: must end with .sql")
}
- migrate.SetTable(MigrationTableName)
-
- sqlDB, err := db.DB()
- if err != nil {
- return errors.Wrap(err, "getting underlying sql.DB")
+ name = strings.TrimSuffix(name, ".sql")
+ parts := strings.SplitN(name, "-", 2)
+ if len(parts) != 2 {
+ return errors.Errorf("invalid migration filename: must be NNN-description.sql")
}
- n, err := migrate.Exec(sqlDB, "postgres", migrations, migrate.Up)
- if err != nil {
- return errors.Wrap(err, "running migrations")
+ version, description := parts[0], parts[1]
+
+ // Validate version is 3 digits
+ if len(version) != 3 {
+ return errors.Errorf("invalid migration filename: version must be 3 digits, got %s", version)
+ }
+ for _, c := range version {
+ if c < '0' || c > '9' {
+ return errors.Errorf("invalid migration filename: version must be numeric, got %s", version)
+ }
}
- log.Printf("Performed %d migrations", n)
+ // Validate description is not empty
+ if description == "" {
+ return errors.Errorf("invalid migration filename: description is required")
+ }
+
+ return nil
+}
+
+// Migrate runs the migrations using the embedded migration files
+func Migrate(db *gorm.DB) error {
+ return migrate(db, migrations.Files)
+}
+
+// getMigrationFiles reads, validates, and sorts migration files
+func getMigrationFiles(fsys fs.FS) ([]migrationFile, error) {
+ entries, err := fs.ReadDir(fsys, ".")
+ if err != nil {
+ return nil, errors.Wrap(err, "reading migration directory")
+ }
+
+ var migrations []migrationFile
+ seen := make(map[int]string)
+ for _, e := range entries {
+ name := e.Name()
+
+ if err := validateMigrationFilename(name); err != nil {
+ return nil, err
+ }
+
+ // Parse version
+ var v int
+ fmt.Sscanf(name, "%d", &v)
+
+ // Check for duplicate version numbers
+ if existing, found := seen[v]; found {
+ return nil, errors.Errorf("duplicate migration version %d: %s and %s", v, existing, name)
+ }
+ seen[v] = name
+
+ migrations = append(migrations, migrationFile{
+ filename: name,
+ version: v,
+ })
+ }
+
+ // Sort by version
+ sort.Slice(migrations, func(i, j int) bool {
+ return migrations[i].version < migrations[j].version
+ })
+
+ return migrations, nil
+}
+
+// migrate runs migrations from the provided filesystem
+func migrate(db *gorm.DB, fsys fs.FS) error {
+ if err := db.Exec(`
+ CREATE TABLE IF NOT EXISTS schema_migrations (
+ version INTEGER PRIMARY KEY,
+ applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
+ )
+ `).Error; err != nil {
+ return errors.Wrap(err, "initializing migration table")
+ }
+
+ // Get current version
+ var version int
+ if err := db.Raw("SELECT COALESCE(MAX(version), 0) FROM schema_migrations").Scan(&version).Error; err != nil {
+ return errors.Wrap(err, "reading current version")
+ }
+
+ // Read and validate migration files
+ migrations, err := getMigrationFiles(fsys)
+ if err != nil {
+ return err
+ }
+
+ var filenames []string
+ for _, m := range migrations {
+ filenames = append(filenames, m.filename)
+ }
+
+ log.WithFields(log.Fields{
+ "version": version,
+ }).Info("Database schema version.")
+
+ log.WithFields(log.Fields{
+ "files": filenames,
+ }).Debug("Database migration files.")
+
+ // Apply pending migrations
+ for _, m := range migrations {
+ if m.version <= version {
+ continue
+ }
+
+ log.WithFields(log.Fields{
+ "file": m.filename,
+ }).Info("Applying migration.")
+
+ sql, err := fs.ReadFile(fsys, m.filename)
+ if err != nil {
+ return errors.Wrapf(err, "reading migration file %s", m.filename)
+ }
+
+ if len(strings.TrimSpace(string(sql))) == 0 {
+ return errors.Errorf("migration file %s is empty", m.filename)
+ }
+
+ if err := db.Exec(string(sql)).Error; err != nil {
+ return fmt.Errorf("migration %s failed: %w", m.filename, err)
+ }
+
+ if err := db.Exec("INSERT INTO schema_migrations (version) VALUES (?)", m.version).Error; err != nil {
+ return errors.Wrapf(err, "recording migration %s", m.filename)
+ }
+
+ log.WithFields(log.Fields{
+ "file": m.filename,
+ }).Info("Migrate success.")
+ }
return nil
}
diff --git a/pkg/server/database/migrate/main.go b/pkg/server/database/migrate/main.go
deleted file mode 100644
index 932f4ec3..00000000
--- a/pkg/server/database/migrate/main.go
+++ /dev/null
@@ -1,72 +0,0 @@
-/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
- *
- * This file is part of Dnote.
- *
- * Dnote is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * Dnote is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with Dnote. If not, see .
- */
-
-package main
-
-import (
- "flag"
- "fmt"
- "os"
-
- "github.com/dnote/dnote/pkg/server/config"
- "github.com/dnote/dnote/pkg/server/database"
- "github.com/joho/godotenv"
- "github.com/pkg/errors"
- "github.com/rubenv/sql-migrate"
-)
-
-var (
- migrationDir = flag.String("migrationDir", "../migrations", "the path to the directory with migraiton files")
-)
-
-func init() {
- fmt.Println("Migrating Dnote database...")
-
- // Load env
- if os.Getenv("GO_ENV") != "PRODUCTION" {
- if err := godotenv.Load("../../.env.dev"); err != nil {
- panic(err)
- }
- }
-
-}
-
-func main() {
- flag.Parse()
-
- c := config.Load()
- db := database.Open(c)
-
- migrations := &migrate.FileMigrationSource{
- Dir: *migrationDir,
- }
-
- migrate.SetTable("migrations")
-
- sqlDB, err := db.DB()
- if err != nil {
- panic(errors.Wrap(err, "getting underlying sql.DB"))
- }
-
- n, err := migrate.Exec(sqlDB, "postgres", migrations, migrate.Up)
- if err != nil {
- panic(errors.Wrap(err, "executing migrations"))
- }
-
- fmt.Printf("Applied %d migrations\n", n)
-}
diff --git a/pkg/server/database/migrate_test.go b/pkg/server/database/migrate_test.go
new file mode 100644
index 00000000..31853730
--- /dev/null
+++ b/pkg/server/database/migrate_test.go
@@ -0,0 +1,313 @@
+/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
+ *
+ * This file is part of Dnote.
+ *
+ * Dnote is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * Dnote is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with Dnote. If not, see .
+ */
+
+package database
+
+import (
+ "io/fs"
+ "testing"
+ "testing/fstest"
+
+ "gorm.io/driver/sqlite"
+ "gorm.io/gorm"
+)
+
+// unsortedFS wraps fstest.MapFS to return entries in reverse order
+type unsortedFS struct {
+ fstest.MapFS
+}
+
+func (u unsortedFS) ReadDir(name string) ([]fs.DirEntry, error) {
+ entries, err := u.MapFS.ReadDir(name)
+ if err != nil {
+ return nil, err
+ }
+ // Reverse the entries to ensure they're not in sorted order
+ for i, j := 0, len(entries)-1; i < j; i, j = i+1, j-1 {
+ entries[i], entries[j] = entries[j], entries[i]
+ }
+ return entries, nil
+}
+
+// errorFS returns an error on ReadDir
+type errorFS struct{}
+
+func (e errorFS) Open(name string) (fs.File, error) {
+ return nil, fs.ErrNotExist
+}
+
+func (e errorFS) ReadDir(name string) ([]fs.DirEntry, error) {
+ return nil, fs.ErrPermission
+}
+
+func TestMigrate_createsSchemaTable(t *testing.T) {
+ db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+ if err != nil {
+ t.Fatalf("failed to open database: %v", err)
+ }
+
+ migrationsFs := fstest.MapFS{}
+ migrate(db, migrationsFs)
+
+ // Verify schema_migrations table exists
+ var count int64
+ if err := db.Raw("SELECT COUNT(*) FROM schema_migrations").Scan(&count).Error; err != nil {
+ t.Fatalf("schema_migrations table not found: %v", err)
+ }
+}
+
+func TestMigrate_idempotency(t *testing.T) {
+ db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+ if err != nil {
+ t.Fatalf("failed to open database: %v", err)
+ }
+
+ // Set up table before migration
+ if err := db.Exec("CREATE TABLE counter (value INTEGER)").Error; err != nil {
+ t.Fatalf("failed to create table: %v", err)
+ }
+
+ // Create migration that inserts a row
+ migrationsFs := fstest.MapFS{
+ "001-insert-data.sql": &fstest.MapFile{
+ Data: []byte("INSERT INTO counter (value) VALUES (100);"),
+ },
+ }
+
+ // Run migration first time
+ if err := migrate(db, migrationsFs); err != nil {
+ t.Fatalf("first migration failed: %v", err)
+ }
+ var count int64
+ if err := db.Raw("SELECT COUNT(*) FROM counter").Scan(&count).Error; err != nil {
+ t.Fatalf("failed to count rows: %v", err)
+ }
+ if count != 1 {
+ t.Errorf("expected 1 row, got %d", count)
+ }
+
+ // Run migration second time - it should not run the SQL again
+ if err := migrate(db, migrationsFs); err != nil {
+ t.Fatalf("second migration failed: %v", err)
+ }
+ if err := db.Raw("SELECT COUNT(*) FROM counter").Scan(&count).Error; err != nil {
+ t.Fatalf("failed to count rows: %v", err)
+ }
+ if count != 1 {
+ t.Errorf("migration ran twice: expected 1 row, got %d", count)
+ }
+}
+
+func TestMigrate_ordering(t *testing.T) {
+ db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+ if err != nil {
+ t.Fatalf("failed to open database: %v", err)
+ }
+
+ // Create table before migrations
+ if err := db.Exec("CREATE TABLE log (value INTEGER)").Error; err != nil {
+ t.Fatalf("failed to create table: %v", err)
+ }
+
+ // Create migrations with unsorted filesystem
+ migrationsFs := unsortedFS{
+ MapFS: fstest.MapFS{
+ "010-tenth.sql": &fstest.MapFile{
+ Data: []byte("INSERT INTO log (value) VALUES (3);"),
+ },
+ "001-first.sql": &fstest.MapFile{
+ Data: []byte("INSERT INTO log (value) VALUES (1);"),
+ },
+ "002-second.sql": &fstest.MapFile{
+ Data: []byte("INSERT INTO log (value) VALUES (2);"),
+ },
+ },
+ }
+
+ // Run migrations
+ if err := migrate(db, migrationsFs); err != nil {
+ t.Fatalf("migration failed: %v", err)
+ }
+
+ // Verify migrations ran in correct order (1, 2, 3)
+ var values []int
+ if err := db.Raw("SELECT value FROM log ORDER BY rowid").Scan(&values).Error; err != nil {
+ t.Fatalf("failed to query log: %v", err)
+ }
+
+ expected := []int{1, 2, 3}
+ if len(values) != len(expected) {
+ t.Fatalf("expected %d rows, got %d", len(expected), len(values))
+ }
+
+ for i, v := range values {
+ if v != expected[i] {
+ t.Errorf("row %d: expected value %d, got %d", i, expected[i], v)
+ }
+ }
+}
+
+func TestMigrate_duplicateVersion(t *testing.T) {
+ db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+ if err != nil {
+ t.Fatalf("failed to open database: %v", err)
+ }
+
+ // Create migrations with duplicate version numbers
+ migrationsFs := fstest.MapFS{
+ "001-first.sql": &fstest.MapFile{
+ Data: []byte("SELECT 1;"),
+ },
+ "001-second.sql": &fstest.MapFile{
+ Data: []byte("SELECT 2;"),
+ },
+ }
+
+ // Should return error for duplicate version
+ err = migrate(db, migrationsFs)
+ if err == nil {
+ t.Fatal("expected error for duplicate version numbers, got nil")
+ }
+}
+
+func TestMigrate_initTableError(t *testing.T) {
+ db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+ if err != nil {
+ t.Fatalf("failed to open database: %v", err)
+ }
+
+ // Close the database connection to cause exec to fail
+ sqlDB, _ := db.DB()
+ sqlDB.Close()
+
+ migrationsFs := fstest.MapFS{
+ "001-init.sql": &fstest.MapFile{
+ Data: []byte("SELECT 1;"),
+ },
+ }
+
+ // Should return error for table initialization failure
+ err = migrate(db, migrationsFs)
+ if err == nil {
+ t.Fatal("expected error for table initialization failure, got nil")
+ }
+}
+
+func TestMigrate_readDirError(t *testing.T) {
+ db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+ if err != nil {
+ t.Fatalf("failed to open database: %v", err)
+ }
+
+ // Use filesystem that fails on ReadDir
+ err = migrate(db, errorFS{})
+ if err == nil {
+ t.Fatal("expected error for ReadDir failure, got nil")
+ }
+}
+
+func TestMigrate_sqlError(t *testing.T) {
+ db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+ if err != nil {
+ t.Fatalf("failed to open database: %v", err)
+ }
+
+ // Create migration with invalid SQL
+ migrationsFs := fstest.MapFS{
+ "001-bad-sql.sql": &fstest.MapFile{
+ Data: []byte("INVALID SQL SYNTAX HERE;"),
+ },
+ }
+
+ // Should return error for SQL execution failure
+ err = migrate(db, migrationsFs)
+ if err == nil {
+ t.Fatal("expected error for invalid SQL, got nil")
+ }
+}
+
+func TestMigrate_emptyFile(t *testing.T) {
+ db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+ if err != nil {
+ t.Fatalf("failed to open database: %v", err)
+ }
+
+ tests := []struct {
+ name string
+ data string
+ wantErr bool
+ }{
+ {"completely empty", "", true},
+ {"only whitespace", " \n\t ", true},
+ {"only comments", "-- just a comment", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ migrationsFs := fstest.MapFS{
+ "001-empty.sql": &fstest.MapFile{
+ Data: []byte(tt.data),
+ },
+ }
+
+ err = migrate(db, migrationsFs)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("error = %v, wantErr %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
+func TestMigrate_invalidFilename(t *testing.T) {
+ db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+ if err != nil {
+ t.Fatalf("failed to open database: %v", err)
+ }
+
+ tests := []struct {
+ name string
+ filename string
+ wantErr bool
+ }{
+ {"valid format", "001-init.sql", false},
+ {"no leading zeros", "1-init.sql", true},
+ {"two digits", "01-init.sql", true},
+ {"no dash", "001init.sql", true},
+ {"no description", "001-.sql", true},
+ {"no extension", "001-init.", true},
+ {"wrong extension", "001-init.txt", true},
+ {"non-numeric version number", "0a1-init.sql", true},
+ {"underscore separator", "001_init.sql", true},
+ {"multiple dashes in description", "001-add-feature-v2.sql", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ migrationsFs := fstest.MapFS{
+ tt.filename: &fstest.MapFile{
+ Data: []byte("SELECT 1;"),
+ },
+ }
+
+ err := migrate(db, migrationsFs)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("error = %v, wantErr %v", err, tt.wantErr)
+ }
+ })
+ }
+}
diff --git a/pkg/server/database/migrations/.gitkeep b/pkg/server/database/migrations/.gitkeep
deleted file mode 100644
index e69de29b..00000000
diff --git a/pkg/server/database/migrations/100-create-fts-table.sql b/pkg/server/database/migrations/100-create-fts-table.sql
new file mode 100644
index 00000000..21b43704
--- /dev/null
+++ b/pkg/server/database/migrations/100-create-fts-table.sql
@@ -0,0 +1,18 @@
+-- Create FTS5 virtual table for full-text search on notes
+CREATE VIRTUAL TABLE IF NOT EXISTS notes_fts USING fts5(
+ content=notes,
+ body,
+ tokenize="porter unicode61 categories 'L* N* Co Ps Pe'"
+);
+
+-- Create triggers to keep notes_fts in sync with notes
+CREATE TRIGGER IF NOT EXISTS notes_insert AFTER INSERT ON notes BEGIN
+ INSERT INTO notes_fts(rowid, body) VALUES (new.rowid, new.body);
+END;
+CREATE TRIGGER IF NOT EXISTS notes_delete AFTER DELETE ON notes BEGIN
+ INSERT INTO notes_fts(notes_fts, rowid, body) VALUES ('delete', old.rowid, old.body);
+END;
+CREATE TRIGGER IF NOT EXISTS notes_update AFTER UPDATE ON notes BEGIN
+ INSERT INTO notes_fts(notes_fts, rowid, body) VALUES ('delete', old.rowid, old.body);
+ INSERT INTO notes_fts(rowid, body) VALUES (new.rowid, new.body);
+END;
\ No newline at end of file
diff --git a/pkg/server/database/migrations/20190819115834-full-text-search.sql b/pkg/server/database/migrations/20190819115834-full-text-search.sql
deleted file mode 100644
index b3d884e9..00000000
--- a/pkg/server/database/migrations/20190819115834-full-text-search.sql
+++ /dev/null
@@ -1,41 +0,0 @@
-
--- +migrate Up
-
--- Configure full text search
-CREATE TEXT SEARCH DICTIONARY english_nostop (
- Template = snowball,
- Language = english
-);
-
-CREATE TEXT SEARCH CONFIGURATION public.english_nostop ( COPY = pg_catalog.english );
-
-ALTER TEXT SEARCH CONFIGURATION public.english_nostop
-ALTER MAPPING FOR asciiword, asciihword, hword_asciipart, hword, hword_part, word WITH english_nostop;
-
-
--- Create a trigger
--- +migrate StatementBegin
-CREATE OR REPLACE FUNCTION note_tsv_trigger() RETURNS trigger AS $$
-begin
- new.tsv := setweight(to_tsvector('english_nostop', new.body), 'A');
- return new;
-end
-$$ LANGUAGE plpgsql;
-
-DROP TRIGGER IF EXISTS tsvectorupdate ON notes;
-CREATE TRIGGER tsvectorupdate
-BEFORE INSERT OR UPDATE ON notes
-FOR EACH ROW EXECUTE PROCEDURE note_tsv_trigger();
--- +migrate StatementEnd
-
--- index tsv
-CREATE INDEX IF NOT EXISTS idx_notes_tsv
-ON notes
-USING gin(tsv);
-
--- initialize tsv
-UPDATE notes
-SET tsv = setweight(to_tsvector('english_nostop', notes.body), 'A')
-WHERE notes.encrypted = false;
-
--- +migrate Down
diff --git a/pkg/server/database/migrations/20191028103522-create-weekly-repetition.sql b/pkg/server/database/migrations/20191028103522-create-weekly-repetition.sql
deleted file mode 100644
index 22ac6f3a..00000000
--- a/pkg/server/database/migrations/20191028103522-create-weekly-repetition.sql
+++ /dev/null
@@ -1,8 +0,0 @@
--- this migration is noop because repetition_rules have been removed
-
--- create-weekly-repetition.sql creates the default repetition rules for the users
--- that used to have the weekly email digest on Friday 20:00 UTC
-
--- +migrate Up
-
--- +migrate Down
diff --git a/pkg/server/database/migrations/20191225185502-populate-digest-version.sql b/pkg/server/database/migrations/20191225185502-populate-digest-version.sql
deleted file mode 100644
index 73098dbf..00000000
--- a/pkg/server/database/migrations/20191225185502-populate-digest-version.sql
+++ /dev/null
@@ -1,9 +0,0 @@
--- this migration is noop because digests have been removed
-
--- populate-digest-version.sql populates the `version` column for the digests
--- by assigining an incremental integer scoped to a repetition rule that each
--- digest belongs, ordered by created_at timestamp of the digests.
-
--- +migrate Up
-
--- +migrate Down
diff --git a/pkg/server/database/migrations/20191226093447-add-digest-id-primary-key.sql b/pkg/server/database/migrations/20191226093447-add-digest-id-primary-key.sql
deleted file mode 100644
index c7d17d2e..00000000
--- a/pkg/server/database/migrations/20191226093447-add-digest-id-primary-key.sql
+++ /dev/null
@@ -1,5 +0,0 @@
--- this migration is noop because digests have been removed
-
--- +migrate Up
-
--- +migrate Down
diff --git a/pkg/server/database/migrations/20191226105659-use-id-in-digest-notes-joining-table.sql b/pkg/server/database/migrations/20191226105659-use-id-in-digest-notes-joining-table.sql
deleted file mode 100644
index faec52ff..00000000
--- a/pkg/server/database/migrations/20191226105659-use-id-in-digest-notes-joining-table.sql
+++ /dev/null
@@ -1,8 +0,0 @@
--- this migration is noop because digests have been removed
-
--- -use-id-in-digest-notes-joining-table.sql replaces uuids with ids
--- as foreign keys in the digest_notes joining table.
-
--- +migrate Up
-
--- +migrate Down
diff --git a/pkg/server/database/migrations/20191226152111-delete-outdated-digests.sql b/pkg/server/database/migrations/20191226152111-delete-outdated-digests.sql
deleted file mode 100644
index 84c8ccf3..00000000
--- a/pkg/server/database/migrations/20191226152111-delete-outdated-digests.sql
+++ /dev/null
@@ -1,8 +0,0 @@
--- this migration is noop because digests have been removed
-
--- delete-outdated-digests.sql deletes digests that do not belong to any repetition rules,
--- along with digest_notes associations.
-
--- +migrate Up
-
--- +migrate Down
diff --git a/pkg/server/database/migrations/20200522170529-remove-billing-columns.sql b/pkg/server/database/migrations/20200522170529-remove-billing-columns.sql
deleted file mode 100644
index a814f26b..00000000
--- a/pkg/server/database/migrations/20200522170529-remove-billing-columns.sql
+++ /dev/null
@@ -1,9 +0,0 @@
--- remove-billing-columns.sql drops billing related columns that are now obsolete.
-
--- +migrate Up
-
-ALTER TABLE users DROP COLUMN IF EXISTS stripe_customer_id;
-ALTER TABLE users DROP COLUMN IF EXISTS billing_country;
-
--- +migrate Down
-
diff --git a/pkg/server/database/models.go b/pkg/server/database/models.go
index 2126c106..ac3e4e56 100644
--- a/pkg/server/database/models.go
+++ b/pkg/server/database/models.go
@@ -25,14 +25,14 @@ import (
// Model is the base model definition
type Model struct {
ID int `gorm:"primaryKey" json:"-"`
- CreatedAt time.Time `json:"created_at" gorm:"default:now()"`
- UpdatedAt time.Time `json:"updated_at"`
+ CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
+ UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
// Book is a model for a book
type Book struct {
Model
- UUID string `json:"uuid" gorm:"uniqueIndex;type:uuid;default:uuid_generate_v4()"`
+ UUID string `json:"uuid" gorm:"uniqueIndex;type:text"`
UserID int `json:"user_id" gorm:"index"`
Label string `json:"label" gorm:"index"`
Notes []Note `json:"notes" gorm:"foreignKey:BookUUID;references:UUID"`
@@ -46,15 +46,14 @@ type Book struct {
// Note is a model for a note
type Note struct {
Model
- UUID string `json:"uuid" gorm:"index;type:uuid;default:uuid_generate_v4()"`
+ UUID string `json:"uuid" gorm:"index;type:text"`
Book Book `json:"book" gorm:"foreignKey:BookUUID;references:UUID"`
User User `json:"user"`
UserID int `json:"user_id" gorm:"index"`
- BookUUID string `json:"book_uuid" gorm:"index;type:uuid"`
+ BookUUID string `json:"book_uuid" gorm:"index;type:text"`
Body string `json:"content"`
AddedOn int64 `json:"added_on"`
EditedOn int64 `json:"edited_on"`
- TSV string `json:"-" gorm:"type:tsvector"`
Public bool `json:"public" gorm:"default:false"`
USN int `json:"-" gorm:"index"`
Deleted bool `json:"-" gorm:"default:false"`
@@ -65,11 +64,10 @@ type Note struct {
// User is a model for a user
type User struct {
Model
- UUID string `json:"uuid" gorm:"type:uuid;index;default:uuid_generate_v4()"`
+ UUID string `json:"uuid" gorm:"type:text;index"`
Account Account `gorm:"foreignKey:UserID"`
LastLoginAt *time.Time `json:"-"`
MaxUSN int `json:"-" gorm:"default:0"`
- Cloud bool `json:"-" gorm:"default:false"`
}
// Account is a model for an account
@@ -90,21 +88,6 @@ type Token struct {
UsedAt *time.Time
}
-// Notification is the learning notification sent to the user
-type Notification struct {
- Model
- Type string
- UserID int `gorm:"index"`
-}
-
-// EmailPreference is a preference per user for receiving email communication
-type EmailPreference struct {
- Model
- UserID int `gorm:"index" json:"-"`
- InactiveReminder bool `json:"inactive_reminder" gorm:"default:false"`
- ProductUpdate bool `json:"product_update" gorm:"default:true"`
-}
-
// Session represents a user session
type Session struct {
Model
diff --git a/pkg/server/database/scripts/create-migration.sh b/pkg/server/database/scripts/create-migration.sh
deleted file mode 100755
index ef168165..00000000
--- a/pkg/server/database/scripts/create-migration.sh
+++ /dev/null
@@ -1,21 +0,0 @@
-#!/usr/bin/env bash
-# create-migration.sh creates a new SQL migration file for the
-# server side Postgres database using the sql-migrate tool.
-set -eux
-
-is_command () {
- command -v "$1" >/dev/null 2>&1;
-}
-
-if ! is_command sql-migrate; then
- echo "sql-migrate is not found. Please run install-sql-migrate.sh"
- exit 1
-fi
-
-if [ "$#" == 0 ]; then
- echo "filename not provided"
- exit 1
-fi
-
-filename=$1
-sql-migrate new -config=./sql-migrate.yml "$filename"
diff --git a/pkg/server/database/scripts/install-sql-migrate.sh b/pkg/server/database/scripts/install-sql-migrate.sh
deleted file mode 100755
index 334fb817..00000000
--- a/pkg/server/database/scripts/install-sql-migrate.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-#!/usr/bin/env bash
-
-go get -v github.com/rubenv/sql-migrate/...
diff --git a/pkg/server/database/sql-migrate.yml b/pkg/server/database/sql-migrate.yml
deleted file mode 100644
index f9c90d83..00000000
--- a/pkg/server/database/sql-migrate.yml
+++ /dev/null
@@ -1,8 +0,0 @@
-# A configuration for sql-migrate tool for generating migrations
-# using `sql-migrate new`. This file is not actually used for running
-# migrations because we run them programmatically.
-
-development:
- dialect: postgres
- datasource: dbname=dnote sslmode=disable
- dir: ./migrations
diff --git a/pkg/server/job/job.go b/pkg/server/job/job.go
deleted file mode 100644
index 12efd823..00000000
--- a/pkg/server/job/job.go
+++ /dev/null
@@ -1,127 +0,0 @@
-/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
- *
- * This file is part of Dnote.
- *
- * Dnote is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * Dnote is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with Dnote. If not, see .
- */
-
-package job
-
-import (
- slog "log"
-
- "github.com/dnote/dnote/pkg/clock"
- "github.com/dnote/dnote/pkg/server/config"
- "github.com/dnote/dnote/pkg/server/mailer"
- "gorm.io/gorm"
- "github.com/pkg/errors"
- "github.com/robfig/cron"
-)
-
-var (
- // ErrEmptyDB is an error for missing database connection in the app configuration
- ErrEmptyDB = errors.New("No database connection was provided")
- // ErrEmptyClock is an error for missing clock in the app configuration
- ErrEmptyClock = errors.New("No clock was provided")
- // ErrEmptyWebURL is an error for missing WebURL content in the app configuration
- ErrEmptyWebURL = errors.New("No WebURL was provided")
- // ErrEmptyEmailTemplates is an error for missing EmailTemplates content in the app configuration
- ErrEmptyEmailTemplates = errors.New("No EmailTemplate store was provided")
- // ErrEmptyEmailBackend is an error for missing EmailBackend content in the app configuration
- ErrEmptyEmailBackend = errors.New("No EmailBackend was provided")
-)
-
-// Runner is a configuration for job
-type Runner struct {
- DB *gorm.DB
- Clock clock.Clock
- EmailTmpl mailer.Templates
- EmailBackend mailer.Backend
- Config config.Config
-}
-
-// NewRunner returns a new runner
-func NewRunner(db *gorm.DB, c clock.Clock, t mailer.Templates, b mailer.Backend, config config.Config) (Runner, error) {
- ret := Runner{
- DB: db,
- EmailTmpl: t,
- EmailBackend: b,
- Clock: c,
- Config: config,
- }
-
- if err := ret.validate(); err != nil {
- return Runner{}, errors.Wrap(err, "validating runner configuration")
- }
-
- return ret, nil
-}
-
-func (r *Runner) validate() error {
- if r.DB == nil {
- return ErrEmptyDB
- }
- if r.Clock == nil {
- return ErrEmptyClock
- }
- if r.EmailTmpl == nil {
- return ErrEmptyEmailTemplates
- }
- if r.EmailBackend == nil {
- return ErrEmptyEmailBackend
- }
- if r.Config.WebURL == "" {
- return ErrEmptyWebURL
- }
-
- return nil
-}
-
-func scheduleJob(c *cron.Cron, spec string, cmd func()) {
- s, err := cron.ParseStandard(spec)
- if err != nil {
- panic(errors.Wrap(err, "parsing schedule"))
- }
-
- c.Schedule(s, cron.FuncJob(cmd))
-}
-
-func (r *Runner) schedule(ch chan error) {
- // Schedule jobs
- cr := cron.New()
- cr.Start()
-
- ch <- nil
-
- // Block forever
- select {}
-}
-
-// Do starts the background tasks in a separate goroutine that runs forever
-func (r *Runner) Do() error {
- // validate
- if err := r.validate(); err != nil {
- return errors.Wrap(err, "validating job configurations")
- }
-
- ch := make(chan error)
- go r.schedule(ch)
- if err := <-ch; err != nil {
- return errors.Wrap(err, "scheduling jobs")
- }
-
- slog.Println("Started background tasks")
-
- return nil
-}
diff --git a/pkg/server/job/job_test.go b/pkg/server/job/job_test.go
deleted file mode 100644
index 885d429f..00000000
--- a/pkg/server/job/job_test.go
+++ /dev/null
@@ -1,104 +0,0 @@
-/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
- *
- * This file is part of Dnote.
- *
- * Dnote is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * Dnote is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with Dnote. If not, see .
- */
-
-package job
-
-import (
- "fmt"
- "testing"
-
- "github.com/dnote/dnote/pkg/assert"
- "github.com/dnote/dnote/pkg/clock"
- "github.com/dnote/dnote/pkg/server/config"
- "github.com/dnote/dnote/pkg/server/mailer"
- "github.com/dnote/dnote/pkg/server/testutils"
- "gorm.io/gorm"
- "github.com/pkg/errors"
-)
-
-func TestNewRunner(t *testing.T) {
- testCases := []struct {
- db *gorm.DB
- clock clock.Clock
- emailTmpl mailer.Templates
- emailBackend mailer.Backend
- webURL string
- expectedErr error
- }{
- {
- db: &gorm.DB{},
- clock: clock.NewMock(),
- emailTmpl: mailer.Templates{},
- emailBackend: &testutils.MockEmailbackendImplementation{},
- webURL: "http://mock.url",
- expectedErr: nil,
- },
- {
- db: nil,
- clock: clock.NewMock(),
- emailTmpl: mailer.Templates{},
- emailBackend: &testutils.MockEmailbackendImplementation{},
- webURL: "http://mock.url",
- expectedErr: ErrEmptyDB,
- },
- {
- db: &gorm.DB{},
- clock: nil,
- emailTmpl: mailer.Templates{},
- emailBackend: &testutils.MockEmailbackendImplementation{},
- webURL: "http://mock.url",
- expectedErr: ErrEmptyClock,
- },
- {
- db: &gorm.DB{},
- clock: clock.NewMock(),
- emailTmpl: nil,
- emailBackend: &testutils.MockEmailbackendImplementation{},
- webURL: "http://mock.url",
- expectedErr: ErrEmptyEmailTemplates,
- },
- {
- db: &gorm.DB{},
- clock: clock.NewMock(),
- emailTmpl: mailer.Templates{},
- emailBackend: nil,
- webURL: "http://mock.url",
- expectedErr: ErrEmptyEmailBackend,
- },
- {
- db: &gorm.DB{},
- clock: clock.NewMock(),
- emailTmpl: mailer.Templates{},
- emailBackend: &testutils.MockEmailbackendImplementation{},
- webURL: "",
- expectedErr: ErrEmptyWebURL,
- },
- }
-
- for idx, tc := range testCases {
- t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
-
- c := config.Load()
- c.WebURL = tc.webURL
-
- _, err := NewRunner(tc.db, tc.clock, tc.emailTmpl, tc.emailBackend, c)
-
- assert.Equal(t, errors.Cause(err), tc.expectedErr, "error mismatch")
- })
- }
-}
diff --git a/pkg/server/log/log.go b/pkg/server/log/log.go
index 78f737bd..89eeb719 100644
--- a/pkg/server/log/log.go
+++ b/pkg/server/log/log.go
@@ -32,9 +32,19 @@ const (
fieldKeyTimestamp = "ts"
fieldKeyUnixTimestamp = "ts_unix"
- levelInfo = "info"
- levelWarn = "warn"
- levelError = "error"
+ // LevelDebug represents debug log level
+ LevelDebug = "debug"
+ // LevelInfo represents info log level
+ LevelInfo = "info"
+ // LevelWarn represents warn log level
+ LevelWarn = "warn"
+ // LevelError represents error log level
+ LevelError = "error"
+)
+
+var (
+ // currentLevel is the currently configured log level
+ currentLevel = LevelInfo
)
// Fields represents a set of information to be included in the log
@@ -58,19 +68,50 @@ func WithFields(fields Fields) Entry {
return newEntry(fields)
}
+// SetLevel sets the global log level
+func SetLevel(level string) {
+ currentLevel = level
+}
+
+// levelPriority returns a numeric priority for comparison
+func levelPriority(level string) int {
+ switch level {
+ case LevelDebug:
+ return 0
+ case LevelInfo:
+ return 1
+ case LevelWarn:
+ return 2
+ case LevelError:
+ return 3
+ default:
+ return 1
+ }
+}
+
+// shouldLog returns true if the given level should be logged based on currentLevel
+func shouldLog(level string) bool {
+ return levelPriority(level) >= levelPriority(currentLevel)
+}
+
+// Debug logs the given entry at a debug level
+func (e Entry) Debug(msg string) {
+ e.write(LevelDebug, msg)
+}
+
// Info logs the given entry at an info level
func (e Entry) Info(msg string) {
- e.write(levelInfo, msg)
+ e.write(LevelInfo, msg)
}
// Warn logs the given entry at a warning level
func (e Entry) Warn(msg string) {
- e.write(levelWarn, msg)
+ e.write(LevelWarn, msg)
}
// Error logs the given entry at an error level
func (e Entry) Error(msg string) {
- e.write(levelError, msg)
+ e.write(LevelError, msg)
}
// ErrorWrap logs the given entry with the error message annotated by the given message
@@ -106,6 +147,10 @@ func (e Entry) formatJSON(level, msg string) []byte {
}
func (e Entry) write(level, msg string) {
+ if !shouldLog(level) {
+ return
+ }
+
serialized := e.formatJSON(level, msg)
_, err := fmt.Fprintln(os.Stderr, string(serialized))
@@ -114,6 +159,11 @@ func (e Entry) write(level, msg string) {
}
}
+// Debug logs a debug message without additional fields
+func Debug(msg string) {
+ newEntry(Fields{}).Debug(msg)
+}
+
// Info logs an info message without additional fields
func Info(msg string) {
newEntry(Fields{}).Info(msg)
diff --git a/pkg/server/mailer/backend.go b/pkg/server/mailer/backend.go
index eb6c3893..0abe51d8 100644
--- a/pkg/server/mailer/backend.go
+++ b/pkg/server/mailer/backend.go
@@ -19,11 +19,10 @@
package mailer
import (
- "fmt"
- "log"
"os"
"strconv"
+ "github.com/dnote/dnote/pkg/server/log"
"github.com/pkg/errors"
"gopkg.in/gomail.v2"
)
@@ -36,9 +35,21 @@ type Backend interface {
Queue(subject, from string, to []string, contentType, body string) error
}
-// SimpleBackendImplementation is an implementation of the Backend
+// EmailDialer is an interface for sending email messages
+type EmailDialer interface {
+ DialAndSend(m ...*gomail.Message) error
+}
+
+// gomailDialer wraps gomail.Dialer to implement EmailDialer interface
+type gomailDialer struct {
+ *gomail.Dialer
+}
+
+// DefaultBackend is an implementation of the Backend
// that sends an email without queueing.
-type SimpleBackendImplementation struct {
+type DefaultBackend struct {
+ Dialer EmailDialer
+ Enabled bool
}
type dialerParams struct {
@@ -73,13 +84,31 @@ func getSMTPParams() (*dialerParams, error) {
return p, nil
}
+// NewDefaultBackend creates a default backend
+func NewDefaultBackend(enabled bool) (*DefaultBackend, error) {
+ p, err := getSMTPParams()
+ if err != nil {
+ return nil, err
+ }
+
+ d := gomail.NewDialer(p.Host, p.Port, p.Username, p.Password)
+
+ return &DefaultBackend{
+ Dialer: &gomailDialer{Dialer: d},
+ Enabled: enabled,
+ }, nil
+}
+
// Queue is an implementation of Backend.Queue.
-func (b *SimpleBackendImplementation) Queue(subject, from string, to []string, contentType, body string) error {
- // If not production, never actually send an email
- if os.Getenv("GO_ENV") != "PRODUCTION" {
- log.Println("Not sending email because Dnote is not running in a production environment.")
- log.Printf("Subject: %s, to: %s, from: %s", subject, to, from)
- fmt.Println(body)
+func (b *DefaultBackend) Queue(subject, from string, to []string, contentType, body string) error {
+ // If not enabled, just log the email
+ if !b.Enabled {
+ log.WithFields(log.Fields{
+ "subject": subject,
+ "to": to,
+ "from": from,
+ "body": body,
+ }).Info("Not sending email because email backend is not configured.")
return nil
}
@@ -89,13 +118,7 @@ func (b *SimpleBackendImplementation) Queue(subject, from string, to []string, c
m.SetHeader("Subject", subject)
m.SetBody(contentType, body)
- p, err := getSMTPParams()
- if err != nil {
- return errors.Wrap(err, "getting dialer params")
- }
-
- d := gomail.NewPlainDialer(p.Host, p.Port, p.Username, p.Password)
- if err := d.DialAndSend(m); err != nil {
+ if err := b.Dialer.DialAndSend(m); err != nil {
return errors.Wrap(err, "dialing and sending email")
}
diff --git a/pkg/server/mailer/backend_test.go b/pkg/server/mailer/backend_test.go
new file mode 100644
index 00000000..5ef0a355
--- /dev/null
+++ b/pkg/server/mailer/backend_test.go
@@ -0,0 +1,107 @@
+/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
+ *
+ * This file is part of Dnote.
+ *
+ * Dnote is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * Dnote is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with Dnote. If not, see .
+ */
+
+package mailer
+
+import (
+ "testing"
+
+ "gopkg.in/gomail.v2"
+)
+
+type mockDialer struct {
+ sentMessages []*gomail.Message
+ err error
+}
+
+func (m *mockDialer) DialAndSend(msgs ...*gomail.Message) error {
+ m.sentMessages = append(m.sentMessages, msgs...)
+ return m.err
+}
+
+func TestDefaultBackendQueue(t *testing.T) {
+ t.Run("enabled sends email", func(t *testing.T) {
+ mock := &mockDialer{}
+ backend := &DefaultBackend{
+ Dialer: mock,
+ Enabled: true,
+ }
+
+ err := backend.Queue("Test Subject", "alice@example.com", []string{"bob@example.com"}, "text/plain", "Test body")
+ if err != nil {
+ t.Fatalf("Queue failed: %v", err)
+ }
+
+ if len(mock.sentMessages) != 1 {
+ t.Errorf("expected 1 message sent, got %d", len(mock.sentMessages))
+ }
+ })
+
+ t.Run("disabled does not send email", func(t *testing.T) {
+ mock := &mockDialer{}
+ backend := &DefaultBackend{
+ Dialer: mock,
+ Enabled: false,
+ }
+
+ err := backend.Queue("Test Subject", "alice@example.com", []string{"bob@example.com"}, "text/plain", "Test body")
+ if err != nil {
+ t.Fatalf("Queue failed: %v", err)
+ }
+
+ if len(mock.sentMessages) != 0 {
+ t.Errorf("expected 0 messages sent when disabled, got %d", len(mock.sentMessages))
+ }
+ })
+}
+
+func TestNewDefaultBackend(t *testing.T) {
+ t.Run("with all env vars set", func(t *testing.T) {
+ t.Setenv("SmtpHost", "smtp.example.com")
+ t.Setenv("SmtpPort", "587")
+ t.Setenv("SmtpUsername", "user@example.com")
+ t.Setenv("SmtpPassword", "secret")
+
+ backend, err := NewDefaultBackend(true)
+ if err != nil {
+ t.Fatalf("NewDefaultBackend failed: %v", err)
+ }
+
+ if backend.Enabled != true {
+ t.Errorf("expected Enabled to be true, got %v", backend.Enabled)
+ }
+ if backend.Dialer == nil {
+ t.Error("expected Dialer to be set")
+ }
+ })
+
+ t.Run("missing SMTP config returns error", func(t *testing.T) {
+ t.Setenv("SmtpHost", "")
+ t.Setenv("SmtpPort", "")
+ t.Setenv("SmtpUsername", "")
+ t.Setenv("SmtpPassword", "")
+
+ _, err := NewDefaultBackend(true)
+ if err == nil {
+ t.Error("expected error when SMTP not configured")
+ }
+ if err != ErrSMTPNotConfigured {
+ t.Errorf("expected ErrSMTPNotConfigured, got %v", err)
+ }
+ })
+}
diff --git a/pkg/server/mailer/mailer.go b/pkg/server/mailer/mailer.go
index 786d703f..d02d8911 100644
--- a/pkg/server/mailer/mailer.go
+++ b/pkg/server/mailer/mailer.go
@@ -21,13 +21,11 @@ package mailer
import (
"bytes"
- "embed"
"fmt"
- htemplate "html/template"
"io"
ttemplate "text/template"
- "github.com/aymerick/douceur/inliner"
+ "github.com/dnote/dnote/pkg/server/mailer/templates"
"github.com/pkg/errors"
)
@@ -40,13 +38,9 @@ var (
EmailTypeEmailVerification = "verify_email"
// EmailTypeWelcome represents an welcome email
EmailTypeWelcome = "welcome"
- // EmailTypeInactiveReminder represents an inactivity reminder email
- EmailTypeInactiveReminder = "inactive"
)
var (
- // EmailKindHTML is the type of html email
- EmailKindHTML = "text/html"
// EmailKindText is the type of text email
EmailKindText = "text/plain"
)
@@ -60,9 +54,6 @@ type template interface {
// Templates holds the parsed email templates
type Templates map[string]template
-//go:embed templates/src
-var templateDir embed.FS
-
func getTemplateKey(name, kind string) string {
return fmt.Sprintf("%s.%s", name, kind)
}
@@ -100,58 +91,21 @@ func NewTemplates() Templates {
if err != nil {
panic(errors.Wrap(err, "initializing password reset template"))
}
- inactiveReminderText, err := initTextTmpl(EmailTypeInactiveReminder)
- if err != nil {
- panic(errors.Wrap(err, "initializing password reset template"))
- }
T := Templates{}
T.set(EmailTypeResetPassword, EmailKindText, passwordResetText)
T.set(EmailTypeResetPasswordAlert, EmailKindText, passwordResetAlertText)
T.set(EmailTypeEmailVerification, EmailKindText, verifyEmailText)
T.set(EmailTypeWelcome, EmailKindText, welcomeText)
- T.set(EmailTypeInactiveReminder, EmailKindText, inactiveReminderText)
return T
}
-// initHTMLTmpl returns a template instance by parsing the template with the
-// given name along with partials
-func initHTMLTmpl(templateName string) (template, error) {
- filename := fmt.Sprintf("templates/src/%s.html", templateName)
-
- content, err := templateDir.ReadFile(filename)
- if err != nil {
- return nil, errors.Wrap(err, "reading template")
- }
- headerContent, err := templateDir.ReadFile("templates/header.html")
- if err != nil {
- return nil, errors.Wrap(err, "reading header template")
- }
- footerContent, err := templateDir.ReadFile("templates/footer.html")
- if err != nil {
- return nil, errors.Wrap(err, "reading footer template")
- }
-
- t := htemplate.New(templateName)
- if _, err = t.Parse(string(content)); err != nil {
- return nil, errors.Wrap(err, "parsing template")
- }
- if _, err = t.Parse(string(headerContent)); err != nil {
- return nil, errors.Wrap(err, "parsing template")
- }
- if _, err = t.Parse(string(footerContent)); err != nil {
- return nil, errors.Wrap(err, "parsing template")
- }
-
- return t, nil
-}
-
// initTextTmpl returns a template instance by parsing the template with the given name
func initTextTmpl(templateName string) (template, error) {
- filename := fmt.Sprintf("templates/src/%s.txt", templateName)
+ filename := fmt.Sprintf("%s.txt", templateName)
- content, err := templateDir.ReadFile(filename)
+ content, err := templates.Files.ReadFile(filename)
if err != nil {
return nil, errors.Wrap(err, "reading template")
}
@@ -165,7 +119,7 @@ func initTextTmpl(templateName string) (template, error) {
}
// Execute executes the template with the given name with the givn data
-func (tmpl Templates) Execute(name, kind string, data interface{}) (string, error) {
+func (tmpl Templates) Execute(name, kind string, data any) (string, error) {
t, err := tmpl.get(name, kind)
if err != nil {
return "", errors.Wrap(err, "getting template")
@@ -176,15 +130,5 @@ func (tmpl Templates) Execute(name, kind string, data interface{}) (string, erro
return "", errors.Wrap(err, "executing the template")
}
- // If HTML email, inline the CSS rules
- if kind == EmailKindHTML {
- html, err := inliner.Inline(buf.String())
- if err != nil {
- return "", errors.Wrap(err, "inlining the css rules")
- }
-
- return html, nil
- }
-
return buf.String(), nil
}
diff --git a/pkg/server/mailer/mailer_test.go b/pkg/server/mailer/mailer_test.go
index 6f24b4fb..df95b1f9 100644
--- a/pkg/server/mailer/mailer_test.go
+++ b/pkg/server/mailer/mailer_test.go
@@ -26,6 +26,26 @@ import (
"github.com/pkg/errors"
)
+func TestAllTemplatesInitialized(t *testing.T) {
+ tmpl := NewTemplates()
+
+ emailTypes := []string{
+ EmailTypeResetPassword,
+ EmailTypeResetPasswordAlert,
+ EmailTypeEmailVerification,
+ EmailTypeWelcome,
+ }
+
+ for _, emailType := range emailTypes {
+ t.Run(emailType, func(t *testing.T) {
+ _, err := tmpl.get(emailType, EmailKindText)
+ if err != nil {
+ t.Errorf("template %s not initialized: %v", emailType, err)
+ }
+ })
+ }
+}
+
func TestEmailVerificationEmail(t *testing.T) {
testCases := []struct {
token string
@@ -101,3 +121,79 @@ func TestResetPasswordEmail(t *testing.T) {
})
}
}
+
+func TestWelcomeEmail(t *testing.T) {
+ testCases := []struct {
+ accountEmail string
+ webURL string
+ }{
+ {
+ accountEmail: "test@example.com",
+ webURL: "http://localhost:3000",
+ },
+ {
+ accountEmail: "user@example.org",
+ webURL: "http://localhost:3001",
+ },
+ }
+
+ tmpl := NewTemplates()
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("with WebURL %s and email %s", tc.webURL, tc.accountEmail), func(t *testing.T) {
+ dat := WelcomeTmplData{
+ AccountEmail: tc.accountEmail,
+ WebURL: tc.webURL,
+ }
+ body, err := tmpl.Execute(EmailTypeWelcome, EmailKindText, dat)
+ if err != nil {
+ t.Fatal(errors.Wrap(err, "executing"))
+ }
+
+ if ok := strings.Contains(body, tc.webURL); !ok {
+ t.Errorf("email body did not contain %s", tc.webURL)
+ }
+ if ok := strings.Contains(body, tc.accountEmail); !ok {
+ t.Errorf("email body did not contain %s", tc.accountEmail)
+ }
+ })
+ }
+}
+
+func TestResetPasswordAlertEmail(t *testing.T) {
+ testCases := []struct {
+ accountEmail string
+ webURL string
+ }{
+ {
+ accountEmail: "test@example.com",
+ webURL: "http://localhost:3000",
+ },
+ {
+ accountEmail: "user@example.org",
+ webURL: "http://localhost:3001",
+ },
+ }
+
+ tmpl := NewTemplates()
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("with WebURL %s and email %s", tc.webURL, tc.accountEmail), func(t *testing.T) {
+ dat := EmailResetPasswordAlertTmplData{
+ AccountEmail: tc.accountEmail,
+ WebURL: tc.webURL,
+ }
+ body, err := tmpl.Execute(EmailTypeResetPasswordAlert, EmailKindText, dat)
+ if err != nil {
+ t.Fatal(errors.Wrap(err, "executing"))
+ }
+
+ if ok := strings.Contains(body, tc.webURL); !ok {
+ t.Errorf("email body did not contain %s", tc.webURL)
+ }
+ if ok := strings.Contains(body, tc.accountEmail); !ok {
+ t.Errorf("email body did not contain %s", tc.accountEmail)
+ }
+ })
+ }
+}
diff --git a/pkg/server/mailer/templates/.env.dev b/pkg/server/mailer/templates/.env.dev
deleted file mode 100644
index 7808cb4a..00000000
--- a/pkg/server/mailer/templates/.env.dev
+++ /dev/null
@@ -1,12 +0,0 @@
-DBHost=localhost
-DBPort=5433
-DBName=dnote
-DBUser=postgres
-DBPassword=
-
-SmtpUsername=mock-SmtpUsername
-SmtpPassword=mock-SmtpPassword
-SmtpHost=mock-SmtpHost
-
-WebURL=http://localhost:3000
-DisableRegistration=false
diff --git a/pkg/server/mailer/templates/.gitignore b/pkg/server/mailer/templates/.gitignore
deleted file mode 100644
index f8a26871..00000000
--- a/pkg/server/mailer/templates/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-templates
diff --git a/pkg/server/mailer/templates/README.md b/pkg/server/mailer/templates/README.md
deleted file mode 100644
index 9329442a..00000000
--- a/pkg/server/mailer/templates/README.md
+++ /dev/null
@@ -1,13 +0,0 @@
-# templates
-
-Email templates
-
-* `/src` contains templates.
-
-## Development
-
-Run the server to develop templates locally.
-
-```
-./dev.sh
-```
diff --git a/pkg/server/mailer/templates/dev.sh b/pkg/server/mailer/templates/dev.sh
deleted file mode 100755
index 035220b1..00000000
--- a/pkg/server/mailer/templates/dev.sh
+++ /dev/null
@@ -1,21 +0,0 @@
-#!/usr/bin/env bash
-set -eux
-
-PID=""
-
-function cleanup {
- if [ "$PID" != "" ]; then
- kill "$PID"
- fi
-}
-trap cleanup EXIT
-
-while true; do
- go build main.go
- ./main &
- PID=$!
- inotifywait -r -e modify .
- kill $PID
-done
-
-
diff --git a/pkg/server/mailer/templates/main.go b/pkg/server/mailer/templates/main.go
deleted file mode 100644
index 90ed940c..00000000
--- a/pkg/server/mailer/templates/main.go
+++ /dev/null
@@ -1,144 +0,0 @@
-/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
- *
- * This file is part of Dnote.
- *
- * Dnote is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * Dnote is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with Dnote. If not, see .
- */
-
-package main
-
-import (
- "log"
- "net/http"
-
- "github.com/dnote/dnote/pkg/server/config"
- "github.com/dnote/dnote/pkg/server/database"
- "github.com/dnote/dnote/pkg/server/mailer"
- "gorm.io/gorm"
- "github.com/joho/godotenv"
- _ "github.com/lib/pq"
-)
-
-func (c Context) passwordResetHandler(w http.ResponseWriter, r *http.Request) {
- data := mailer.EmailResetPasswordTmplData{
- AccountEmail: "alice@example.com",
- Token: "testToken",
- WebURL: "http://localhost:3000",
- }
- body, err := c.Tmpl.Execute(mailer.EmailTypeResetPassword, mailer.EmailKindText, data)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
-
- w.Write([]byte(body))
-}
-
-func (c Context) passwordResetAlertHandler(w http.ResponseWriter, r *http.Request) {
- data := mailer.EmailResetPasswordAlertTmplData{
- AccountEmail: "alice@example.com",
- WebURL: "http://localhost:3000",
- }
- body, err := c.Tmpl.Execute(mailer.EmailTypeResetPasswordAlert, mailer.EmailKindText, data)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
-
- w.Write([]byte(body))
-}
-
-func (c Context) emailVerificationHandler(w http.ResponseWriter, r *http.Request) {
- data := mailer.EmailVerificationTmplData{
- Token: "testToken",
- WebURL: "http://localhost:3000",
- }
- body, err := c.Tmpl.Execute(mailer.EmailTypeEmailVerification, mailer.EmailKindText, data)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
-
- w.Write([]byte(body))
-}
-
-func (c Context) welcomeHandler(w http.ResponseWriter, r *http.Request) {
- data := mailer.WelcomeTmplData{
- AccountEmail: "alice@example.com",
- WebURL: "http://localhost:3000",
- }
- body, err := c.Tmpl.Execute(mailer.EmailTypeWelcome, mailer.EmailKindText, data)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
-
- w.Write([]byte(body))
-}
-
-func (c Context) inactiveHandler(w http.ResponseWriter, r *http.Request) {
- data := mailer.InactiveReminderTmplData{
- SampleNoteUUID: "some-uuid",
- WebURL: "http://localhost:3000",
- Token: "some-random-token",
- }
- body, err := c.Tmpl.Execute(mailer.EmailTypeInactiveReminder, mailer.EmailKindText, data)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
-
- w.Write([]byte(body))
-}
-
-func (c Context) homeHandler(w http.ResponseWriter, r *http.Request) {
- w.Write([]byte("Email development server is running."))
-}
-
-func init() {
- err := godotenv.Load(".env.dev")
- if err != nil {
- panic(err)
- }
-}
-
-// Context is a context holding global information
-type Context struct {
- DB *gorm.DB
- Tmpl mailer.Templates
-}
-
-func main() {
- c := config.Load()
- db := database.Open(c)
- defer func() {
- sqlDB, err := db.DB()
- if err == nil {
- sqlDB.Close()
- }
- }()
-
- log.Println("Email template development server running on http://127.0.0.1:2300")
-
- tmpl := mailer.NewTemplates()
- ctx := Context{DB: db, Tmpl: tmpl}
-
- http.HandleFunc("/", ctx.homeHandler)
- http.HandleFunc("/email-verification", ctx.emailVerificationHandler)
- http.HandleFunc("/password-reset", ctx.passwordResetHandler)
- http.HandleFunc("/password-reset-alert", ctx.passwordResetAlertHandler)
- http.HandleFunc("/welcome", ctx.welcomeHandler)
- http.HandleFunc("/inactive-reminder", ctx.inactiveHandler)
- log.Fatal(http.ListenAndServe(":2300", nil))
-}
diff --git a/pkg/server/mailer/templates/reset_password.txt b/pkg/server/mailer/templates/reset_password.txt
new file mode 100644
index 00000000..9053a493
--- /dev/null
+++ b/pkg/server/mailer/templates/reset_password.txt
@@ -0,0 +1,5 @@
+You are receiving this because you requested to reset the password of the '{{ .AccountEmail }}' Dnote account.
+
+Please click on the following link, or paste this into your browser to complete the process:
+
+ {{ .WebURL }}/password-reset/{{ .Token }}
diff --git a/pkg/server/mailer/templates/src/reset_password_alert.txt b/pkg/server/mailer/templates/reset_password_alert.txt
similarity index 50%
rename from pkg/server/mailer/templates/src/reset_password_alert.txt
rename to pkg/server/mailer/templates/reset_password_alert.txt
index 3aa9bdd6..16957375 100644
--- a/pkg/server/mailer/templates/src/reset_password_alert.txt
+++ b/pkg/server/mailer/templates/reset_password_alert.txt
@@ -2,7 +2,7 @@ Hi,
This email is to notify you that the password for your Dnote account "{{ .AccountEmail }}" has changed.
-If you did not initiate this password change, please notify us by replying, and reset your password at {{ .WebURL }}/password-reset
+If you did not initiate this password change, reset your password at {{ .WebURL }}/password-reset.
Thanks.
diff --git a/pkg/server/mailer/templates/scripts/run.sh b/pkg/server/mailer/templates/scripts/run.sh
deleted file mode 100755
index fd8e8ac5..00000000
--- a/pkg/server/mailer/templates/scripts/run.sh
+++ /dev/null
@@ -1 +0,0 @@
-CompileDaemon -directory=. -command="./templates" -include="*.html"
diff --git a/pkg/server/mailer/templates/src/inactive.txt b/pkg/server/mailer/templates/src/inactive.txt
deleted file mode 100644
index b6f4d508..00000000
--- a/pkg/server/mailer/templates/src/inactive.txt
+++ /dev/null
@@ -1,9 +0,0 @@
-Hi, nothing has been added to your Dnote for some time.
-
-What about revisiting one of your previous notes? {{ .WebURL }}/notes/{{ .SampleNoteUUID }}
-
-You can add new notes at {{ .WebURL }}/new or using Dnote apps.
-
-- Dnote team
-
-UNSUBSCRIBE: {{ .WebURL }}/settings/notifications?token={{ .Token }}
diff --git a/pkg/server/mailer/templates/src/reset_password.txt b/pkg/server/mailer/templates/src/reset_password.txt
deleted file mode 100644
index 3bc34850..00000000
--- a/pkg/server/mailer/templates/src/reset_password.txt
+++ /dev/null
@@ -1,9 +0,0 @@
-You are receiving this because you (or someone else) requested to reset the password of the '{{ .AccountEmail }}' Dnote account.
-
-Please click on the following link, or paste this into your browser to complete the process:
-
- {{ .WebURL }}/password-reset/{{ .Token }}
-
-You can reply to this message, if you have questions.
-
-- Dnote team
diff --git a/pkg/server/app/main_test.go b/pkg/server/mailer/templates/templates.go
similarity index 78%
rename from pkg/server/app/main_test.go
rename to pkg/server/mailer/templates/templates.go
index d757da42..c59d3a14 100644
--- a/pkg/server/app/main_test.go
+++ b/pkg/server/mailer/templates/templates.go
@@ -16,20 +16,10 @@
* along with Dnote. If not, see .
*/
-package app
+// Package mailer provides a functionality to send emails
+package templates
-import (
- "os"
- "testing"
+import "embed"
- "github.com/dnote/dnote/pkg/server/testutils"
-)
-
-func TestMain(m *testing.M) {
- testutils.InitTestDB()
-
- code := m.Run()
- testutils.ClearData(testutils.DB)
-
- os.Exit(code)
-}
+//go:embed *.txt
+var Files embed.FS
diff --git a/pkg/server/mailer/templates/src/verify_email.txt b/pkg/server/mailer/templates/verify_email.txt
similarity index 72%
rename from pkg/server/mailer/templates/src/verify_email.txt
rename to pkg/server/mailer/templates/verify_email.txt
index a85ab705..c21af88d 100644
--- a/pkg/server/mailer/templates/src/verify_email.txt
+++ b/pkg/server/mailer/templates/verify_email.txt
@@ -1,9 +1,5 @@
-Hi.
+Hi,
Welcome to Dnote! To verify your email, visit the following link:
{{ .WebURL }}/verify-email/{{ .Token }}
-
-Thanks for using Dnote.
-
-- Dnote team
diff --git a/pkg/server/mailer/templates/src/welcome.txt b/pkg/server/mailer/templates/welcome.txt
similarity index 83%
rename from pkg/server/mailer/templates/src/welcome.txt
rename to pkg/server/mailer/templates/welcome.txt
index 7a33207a..72d0fdf0 100644
--- a/pkg/server/mailer/templates/src/welcome.txt
+++ b/pkg/server/mailer/templates/welcome.txt
@@ -10,7 +10,3 @@ If you ever forget your password, you can reset it at {{ .WebURL }}/password-res
SOURCE CODE
Dnote is open source and you can see the source code at https://github.com/dnote/dnote
-
-Feel free to reply anytime. Thanks for using Dnote.
-
-- Dnote team
diff --git a/pkg/server/mailer/tokens.go b/pkg/server/mailer/tokens.go
index 7d78725f..0f751a4e 100644
--- a/pkg/server/mailer/tokens.go
+++ b/pkg/server/mailer/tokens.go
@@ -21,19 +21,17 @@ package mailer
import (
"crypto/rand"
"encoding/base64"
- "errors"
"github.com/dnote/dnote/pkg/server/database"
- pkgErrors "github.com/pkg/errors"
+ "github.com/pkg/errors"
"gorm.io/gorm"
)
func generateRandomToken(bits int) (string, error) {
b := make([]byte, bits)
- _, err := rand.Read(b)
- if err != nil {
- return "", pkgErrors.Wrap(err, "generating random bytes")
+ if _, err := rand.Read(b); err != nil {
+ return "", errors.Wrap(err, "generating random bytes")
}
return base64.URLEncoding.EncodeToString(b), nil
@@ -49,7 +47,7 @@ func GetToken(db *gorm.DB, userID int, kind string) (database.Token, error) {
tokenVal, genErr := generateRandomToken(16)
if genErr != nil {
- return tok, pkgErrors.Wrap(genErr, "generating token value")
+ return tok, errors.Wrap(genErr, "generating token value")
}
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -59,12 +57,12 @@ func GetToken(db *gorm.DB, userID int, kind string) (database.Token, error) {
Value: tokenVal,
}
if err := db.Save(&tok).Error; err != nil {
- return tok, pkgErrors.Wrap(err, "saving token")
+ return tok, errors.Wrap(err, "saving token")
}
return tok, nil
} else if err != nil {
- return tok, pkgErrors.Wrap(err, "finding token")
+ return tok, errors.Wrap(err, "finding token")
}
return tok, nil
diff --git a/pkg/server/mailer/tokens_test.go b/pkg/server/mailer/tokens_test.go
new file mode 100644
index 00000000..72a85fc6
--- /dev/null
+++ b/pkg/server/mailer/tokens_test.go
@@ -0,0 +1,83 @@
+/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
+ *
+ * This file is part of Dnote.
+ *
+ * Dnote is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * Dnote is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with Dnote. If not, see .
+ */
+
+package mailer
+
+import (
+ "testing"
+
+ "github.com/dnote/dnote/pkg/server/database"
+ "github.com/dnote/dnote/pkg/server/testutils"
+)
+
+func TestGetToken(t *testing.T) {
+ db := testutils.InitMemoryDB(t)
+
+ userID := 1
+ tokenType := "email_verification"
+
+ t.Run("creates new token", func(t *testing.T) {
+ token, err := GetToken(db, userID, tokenType)
+ if err != nil {
+ t.Fatalf("GetToken failed: %v", err)
+ }
+
+ if token.UserID != userID {
+ t.Errorf("expected UserID %d, got %d", userID, token.UserID)
+ }
+ if token.Type != tokenType {
+ t.Errorf("expected Type %s, got %s", tokenType, token.Type)
+ }
+ if token.Value == "" {
+ t.Error("expected non-empty token Value")
+ }
+ if token.UsedAt != nil {
+ t.Error("expected UsedAt to be nil for new token")
+ }
+ })
+
+ t.Run("reuses unused token", func(t *testing.T) {
+ // Get token again - should return the same one
+ token2, err := GetToken(db, userID, tokenType)
+ if err != nil {
+ t.Fatalf("second GetToken failed: %v", err)
+ }
+
+ // Get first token to compare
+ var token1 database.Token
+ if err := db.Where("user_id = ? AND type = ?", userID, tokenType).First(&token1).Error; err != nil {
+ t.Fatalf("failed to get first token: %v", err)
+ }
+
+ if token1.ID != token2.ID {
+ t.Errorf("expected same token ID %d, got %d", token1.ID, token2.ID)
+ }
+ if token1.Value != token2.Value {
+ t.Errorf("expected same token Value %s, got %s", token1.Value, token2.Value)
+ }
+
+ // Verify only one token exists in database
+ var count int64
+ if err := db.Model(&database.Token{}).Where("user_id = ? AND type = ?", userID, tokenType).Count(&count).Error; err != nil {
+ t.Fatalf("failed to count tokens: %v", err)
+ }
+ if count != 1 {
+ t.Errorf("expected 1 token in database, got %d", count)
+ }
+ })
+}
diff --git a/pkg/server/mailer/types.go b/pkg/server/mailer/types.go
index da9a448c..3a371911 100644
--- a/pkg/server/mailer/types.go
+++ b/pkg/server/mailer/types.go
@@ -42,16 +42,3 @@ type WelcomeTmplData struct {
AccountEmail string
WebURL string
}
-
-// InactiveReminderTmplData is a template data for welcome emails
-type InactiveReminderTmplData struct {
- SampleNoteUUID string
- WebURL string
- Token string
-}
-
-// EmailTypeSubscriptionConfirmationTmplData is a template data for reset password emails
-type EmailTypeSubscriptionConfirmationTmplData struct {
- AccountEmail string
- WebURL string
-}
diff --git a/pkg/server/main.go b/pkg/server/main.go
index d8df31e0..4e36b96d 100644
--- a/pkg/server/main.go
+++ b/pkg/server/main.go
@@ -21,8 +21,8 @@ package main
import (
"flag"
"fmt"
- "log"
"net/http"
+ "os"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/app"
@@ -30,54 +30,81 @@ import (
"github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/controllers"
"github.com/dnote/dnote/pkg/server/database"
- "github.com/dnote/dnote/pkg/server/job"
+ "github.com/dnote/dnote/pkg/server/log"
"github.com/dnote/dnote/pkg/server/mailer"
- "gorm.io/driver/postgres"
- "gorm.io/gorm"
-
"github.com/pkg/errors"
+ "gorm.io/gorm"
)
-var port = flag.String("port", "3000", "port to connect to")
-
-func initDB(c config.Config) *gorm.DB {
- db, err := gorm.Open(postgres.Open(c.DB.GetConnectionStr()), &gorm.Config{})
- if err != nil {
- panic(errors.Wrap(err, "opening database connection"))
- }
+func initDB(dbPath string) *gorm.DB {
+ db := database.Open(dbPath)
database.InitSchema(db)
+ database.Migrate(db)
return db
}
func initApp(cfg config.Config) app.App {
- db := initDB(cfg)
+ db := initDB(cfg.DBPath)
+
+ emailBackend, err := mailer.NewDefaultBackend(cfg.IsProd())
+ if err != nil {
+ emailBackend = &mailer.DefaultBackend{Enabled: false}
+ } else {
+ log.Info("Email backend configured")
+ }
return app.App{
- DB: db,
- Clock: clock.New(),
- EmailTemplates: mailer.NewTemplates(),
- EmailBackend: &mailer.SimpleBackendImplementation{},
- Config: cfg,
- HTTP500Page: cfg.HTTP500Page,
+ DB: db,
+ Clock: clock.New(),
+ EmailTemplates: mailer.NewTemplates(),
+ EmailBackend: emailBackend,
+ HTTP500Page: cfg.HTTP500Page,
+ AppEnv: cfg.AppEnv,
+ WebURL: cfg.WebURL,
+ DisableRegistration: cfg.DisableRegistration,
+ Port: cfg.Port,
+ DBPath: cfg.DBPath,
+ AssetBaseURL: cfg.AssetBaseURL,
}
}
-func runJob(a app.App) error {
- runner, err := job.NewRunner(a.DB, a.Clock, a.EmailTemplates, a.EmailBackend, a.Config)
+func startCmd(args []string) {
+ startFlags := flag.NewFlagSet("start", flag.ExitOnError)
+ startFlags.Usage = func() {
+ fmt.Printf(`Usage:
+ dnote-server start [flags]
+
+Flags:
+`)
+ startFlags.PrintDefaults()
+ }
+
+ appEnv := startFlags.String("appEnv", "", "Application environment (env: APP_ENV, default: PRODUCTION)")
+ port := startFlags.String("port", "", "Server port (env: PORT, default: 3000)")
+ webURL := startFlags.String("webUrl", "", "Full URL to server without trailing slash (env: WebURL, example: https://example.com)")
+ dbPath := startFlags.String("dbPath", "", "Path to SQLite database file (env: DBPath, default: $XDG_DATA_HOME/dnote/server.db)")
+ disableRegistration := startFlags.Bool("disableRegistration", false, "Disable user registration (env: DisableRegistration, default: false)")
+ logLevel := startFlags.String("logLevel", "", "Log level: debug, info, warn, or error (env: LOG_LEVEL, default: info)")
+
+ startFlags.Parse(args)
+
+ cfg, err := config.New(config.Params{
+ AppEnv: *appEnv,
+ Port: *port,
+ WebURL: *webURL,
+ DBPath: *dbPath,
+ DisableRegistration: *disableRegistration,
+ LogLevel: *logLevel,
+ })
if err != nil {
- return errors.Wrap(err, "getting a job runner")
- }
- if err := runner.Do(); err != nil {
- return errors.Wrap(err, "running job")
+ fmt.Printf("Error: %s\n\n", err)
+ startFlags.Usage()
+ os.Exit(1)
}
- return nil
-}
-
-func startCmd() {
- cfg := config.Load()
- cfg.SetAssetBaseURL("/static")
+ // Set log level
+ log.SetLevel(cfg.LogLevel)
app := initApp(cfg)
defer func() {
@@ -87,13 +114,6 @@ func startCmd() {
}
}()
- if err := database.Migrate(app.DB); err != nil {
- panic(errors.Wrap(err, "running migrations"))
- }
- if err := runJob(app); err != nil {
- panic(errors.Wrap(err, "running job"))
- }
-
ctl := controllers.New(&app)
rc := controllers.RouteConfig{
WebRoutes: controllers.NewWebRoutes(&app, ctl),
@@ -106,8 +126,15 @@ func startCmd() {
panic(errors.Wrap(err, "initializing router"))
}
- log.Printf("Dnote version %s is running on port %s", buildinfo.Version, *port)
- log.Fatalln(http.ListenAndServe(fmt.Sprintf(":%s", *port), r))
+ log.WithFields(log.Fields{
+ "version": buildinfo.Version,
+ "port": cfg.Port,
+ }).Info("Dnote server starting")
+
+ if err := http.ListenAndServe(fmt.Sprintf(":%s", cfg.Port), r); err != nil {
+ log.ErrorWrap(err, "server failed")
+ os.Exit(1)
+ }
}
func versionCmd() {
@@ -115,29 +142,33 @@ func versionCmd() {
}
func rootCmd() {
- fmt.Printf(`Dnote server - a simple personal knowledge base
+ fmt.Printf(`Dnote server - a simple command line notebook
Usage:
- dnote-server [command]
+ dnote-server [command] [flags]
Available commands:
- start: Start the server
+ start: Start the server (use 'dnote-server start --help' for flags)
version: Print the version
`)
}
func main() {
- flag.Parse()
- cmd := flag.Arg(0)
+ if len(os.Args) < 2 {
+ rootCmd()
+ return
+ }
+
+ cmd := os.Args[1]
switch cmd {
- case "":
- rootCmd()
case "start":
- startCmd()
+ startCmd(os.Args[2:])
case "version":
versionCmd()
default:
- fmt.Printf("Unknown command %s", cmd)
+ fmt.Printf("Unknown command %s\n", cmd)
+ rootCmd()
+ os.Exit(1)
}
}
diff --git a/pkg/server/middleware/auth.go b/pkg/server/middleware/auth.go
index 28af8760..984079e4 100644
--- a/pkg/server/middleware/auth.go
+++ b/pkg/server/middleware/auth.go
@@ -22,19 +22,17 @@ import (
"errors"
"net/http"
"net/url"
- "strings"
"time"
- "github.com/dnote/dnote/pkg/server/app"
"github.com/dnote/dnote/pkg/server/context"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/log"
- "gorm.io/gorm"
pkgErrors "github.com/pkg/errors"
+ "gorm.io/gorm"
)
-func authWithToken(db *gorm.DB, r *http.Request, tokenType string, p *AuthParams) (database.User, database.Token, bool, error) {
+func authWithToken(db *gorm.DB, r *http.Request, tokenType string) (database.User, database.Token, bool, error) {
var user database.User
var token database.Token
@@ -62,32 +60,17 @@ func authWithToken(db *gorm.DB, r *http.Request, tokenType string, p *AuthParams
return user, token, true, nil
}
-// Cors allows browser extensions to load resources
-func Cors(next http.HandlerFunc) http.HandlerFunc {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- origin := r.Header.Get("Origin")
-
- // Allow browser extensions
- if strings.HasPrefix(origin, "moz-extension://") || strings.HasPrefix(origin, "chrome-extension://") {
- w.Header().Set("Access-Control-Allow-Origin", origin)
- }
-
- next.ServeHTTP(w, r)
- })
-}
-
// AuthParams is the params for the authentication middleware
type AuthParams struct {
- ProOnly bool
RedirectGuestsToLogin bool
}
// Auth is an authentication middleware
-func Auth(a *app.App, next http.HandlerFunc, p *AuthParams) http.HandlerFunc {
- next = WithAccount(a, next)
+func Auth(db *gorm.DB, next http.HandlerFunc, p *AuthParams) http.HandlerFunc {
+ next = WithAccount(db, next)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- user, ok, err := AuthWithSession(a.DB, r)
+ user, ok, err := AuthWithSession(db, r)
if !ok {
if p != nil && p.RedirectGuestsToLogin {
@@ -107,25 +90,18 @@ func Auth(a *app.App, next http.HandlerFunc, p *AuthParams) http.HandlerFunc {
return
}
- if p != nil && p.ProOnly {
- if !user.Cloud {
- RespondForbidden(w)
- return
- }
- }
-
ctx := context.WithUser(r.Context(), &user)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
-func WithAccount(a *app.App, next http.HandlerFunc) http.HandlerFunc {
+func WithAccount(db *gorm.DB, next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := context.User(r.Context())
var account database.Account
- if err := a.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
+ if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
DoError(w, "finding account", err, http.StatusInternalServerError)
return
}
@@ -137,9 +113,9 @@ func WithAccount(a *app.App, next http.HandlerFunc) http.HandlerFunc {
}
// TokenAuth is an authentication middleware with token
-func TokenAuth(a *app.App, next http.HandlerFunc, tokenType string, p *AuthParams) http.HandlerFunc {
+func TokenAuth(db *gorm.DB, next http.HandlerFunc, tokenType string, p *AuthParams) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- user, token, ok, err := authWithToken(a.DB, r, tokenType, p)
+ user, token, ok, err := authWithToken(db, r, tokenType)
if err != nil {
// log the error and continue
log.ErrorWrap(err, "authenticating with token")
@@ -151,7 +127,7 @@ func TokenAuth(a *app.App, next http.HandlerFunc, tokenType string, p *AuthParam
ctx = context.WithToken(ctx, &token)
} else {
// If token-based auth fails, fall back to session-based auth
- user, ok, err = AuthWithSession(a.DB, r)
+ user, ok, err = AuthWithSession(db, r)
if err != nil {
DoError(w, "authenticating with session", err, http.StatusInternalServerError)
return
@@ -163,13 +139,6 @@ func TokenAuth(a *app.App, next http.HandlerFunc, tokenType string, p *AuthParam
}
}
- if p != nil && p.ProOnly {
- if !user.Cloud {
- RespondForbidden(w)
- return
- }
- }
-
ctx = context.WithUser(ctx, &user)
next.ServeHTTP(w, r.WithContext(ctx))
})
@@ -211,9 +180,9 @@ func AuthWithSession(db *gorm.DB, r *http.Request) (database.User, bool, error)
return user, true, nil
}
-func GuestOnly(a *app.App, next http.HandlerFunc) http.HandlerFunc {
+func GuestOnly(db *gorm.DB, next http.HandlerFunc) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _, ok, err := AuthWithSession(a.DB, r)
+ _, ok, err := AuthWithSession(db, r)
if err != nil {
// log the error and continue
log.ErrorWrap(err, "authenticating with session")
diff --git a/pkg/server/middleware/auth_test.go b/pkg/server/middleware/auth_test.go
new file mode 100644
index 00000000..8451ae5d
--- /dev/null
+++ b/pkg/server/middleware/auth_test.go
@@ -0,0 +1,235 @@
+/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
+ *
+ * This file is part of Dnote.
+ *
+ * Dnote is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * Dnote is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with Dnote. If not, see .
+ */
+
+package middleware
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/dnote/dnote/pkg/assert"
+ "github.com/dnote/dnote/pkg/server/database"
+ "github.com/dnote/dnote/pkg/server/testutils"
+)
+
+func TestGuestOnly(t *testing.T) {
+ db := testutils.InitMemoryDB(t)
+
+ handler := func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }
+
+ server := httptest.NewServer(GuestOnly(db, handler))
+ defer server.Close()
+
+ t.Run("guest", func(t *testing.T) {
+ req := testutils.MakeReq(server.URL, "GET", "/", "")
+ res := testutils.HTTPDo(t, req)
+
+ assert.Equal(t, res.StatusCode, http.StatusOK, "status code mismatch")
+ })
+
+ t.Run("logged in", func(t *testing.T) {
+ user := testutils.SetupUserData(db)
+ req := testutils.MakeReq(server.URL, "GET", "/", "")
+ res := testutils.HTTPAuthDo(t, db, req, user)
+
+ assert.Equal(t, res.StatusCode, http.StatusFound, "status code mismatch")
+ assert.Equal(t, res.Header.Get("Location"), "/", "location mismatch")
+ })
+
+ t.Run("error getting credential", func(t *testing.T) {
+ req := testutils.MakeReq(server.URL, "GET", "/", "")
+ req.Header.Set("Authorization", "InvalidFormat")
+ res := testutils.HTTPDo(t, req)
+
+ assert.Equal(t, res.StatusCode, http.StatusOK, "status code mismatch")
+ })
+}
+
+func TestAuth(t *testing.T) {
+ db := testutils.InitMemoryDB(t)
+
+ user := testutils.SetupUserData(db)
+ testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
+
+ session := database.Session{
+ Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=",
+ UserID: user.ID,
+ ExpiresAt: time.Now().Add(time.Hour * 24),
+ }
+ testutils.MustExec(t, db.Save(&session), "preparing session")
+ expiredSession := database.Session{
+ Key: "Vvgm3eBXfXGEFWERI7faiRJ3DAzJw+7DdT9J1LEyNfI=",
+ UserID: user.ID,
+ ExpiresAt: time.Now().Add(-time.Hour * 24),
+ }
+ testutils.MustExec(t, db.Save(&expiredSession), "preparing expired session")
+
+ handler := func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }
+
+ t.Run("valid session with header", func(t *testing.T) {
+ server := httptest.NewServer(Auth(db, handler, nil))
+ defer server.Close()
+
+ req := testutils.MakeReq(server.URL, "GET", "/", "")
+ req.Header.Set("Authorization", "Bearer "+session.Key)
+ res := testutils.HTTPDo(t, req)
+
+ assert.Equal(t, res.StatusCode, http.StatusOK, "status code mismatch")
+ })
+
+ t.Run("expired session with header", func(t *testing.T) {
+ server := httptest.NewServer(Auth(db, handler, nil))
+ defer server.Close()
+
+ req := testutils.MakeReq(server.URL, "GET", "/", "")
+ req.Header.Set("Authorization", "Bearer "+expiredSession.Key)
+ res := testutils.HTTPDo(t, req)
+
+ assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch")
+ })
+
+ t.Run("invalid session with header", func(t *testing.T) {
+ server := httptest.NewServer(Auth(db, handler, nil))
+ defer server.Close()
+
+ req := testutils.MakeReq(server.URL, "GET", "/", "")
+ req.Header.Set("Authorization", "Bearer someInvalidSessionKey=")
+ res := testutils.HTTPDo(t, req)
+
+ assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch")
+ })
+
+ t.Run("valid session with cookie", func(t *testing.T) {
+ server := httptest.NewServer(Auth(db, handler, nil))
+ defer server.Close()
+
+ req := testutils.MakeReq(server.URL, "GET", "/", "")
+ req.AddCookie(&http.Cookie{
+ Name: "id",
+ Value: session.Key,
+ HttpOnly: true,
+ })
+ res := testutils.HTTPDo(t, req)
+
+ assert.Equal(t, res.StatusCode, http.StatusOK, "status code mismatch")
+ })
+
+ t.Run("expired session with cookie", func(t *testing.T) {
+ server := httptest.NewServer(Auth(db, handler, nil))
+ defer server.Close()
+
+ req := testutils.MakeReq(server.URL, "GET", "/", "")
+ req.AddCookie(&http.Cookie{
+ Name: "id",
+ Value: expiredSession.Key,
+ HttpOnly: true,
+ })
+ res := testutils.HTTPDo(t, req)
+
+ assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch")
+ })
+
+ t.Run("no auth", func(t *testing.T) {
+ server := httptest.NewServer(Auth(db, handler, nil))
+ defer server.Close()
+
+ req := testutils.MakeReq(server.URL, "GET", "/", "")
+ res := testutils.HTTPDo(t, req)
+
+ assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch")
+ })
+
+ t.Run("redirect guests to login", func(t *testing.T) {
+ server := httptest.NewServer(Auth(db, handler, &AuthParams{RedirectGuestsToLogin: true}))
+ defer server.Close()
+
+ req := testutils.MakeReq(server.URL, "GET", "/settings", "")
+ res := testutils.HTTPDo(t, req)
+
+ assert.Equal(t, res.StatusCode, http.StatusFound, "status code mismatch")
+ assert.Equal(t, res.Header.Get("Location"), "/login?referrer=%2Fsettings", "location mismatch")
+ })
+}
+
+func TestTokenAuth(t *testing.T) {
+ db := testutils.InitMemoryDB(t)
+
+ user := testutils.SetupUserData(db)
+ tok := database.Token{
+ UserID: user.ID,
+ Type: database.TokenTypeEmailVerification,
+ Value: "xpwFnc0MdllFUePDq9DLeQ==",
+ }
+ testutils.MustExec(t, db.Save(&tok), "preparing token")
+ session := database.Session{
+ Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=",
+ UserID: user.ID,
+ ExpiresAt: time.Now().Add(time.Hour * 24),
+ }
+ testutils.MustExec(t, db.Save(&session), "preparing session")
+
+ handler := func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }
+
+ server := httptest.NewServer(TokenAuth(db, handler, database.TokenTypeEmailVerification, nil))
+ defer server.Close()
+
+ t.Run("with token", func(t *testing.T) {
+ req := testutils.MakeReq(server.URL, "GET", "/?token=xpwFnc0MdllFUePDq9DLeQ==", "")
+ res := testutils.HTTPDo(t, req)
+
+ assert.Equal(t, res.StatusCode, http.StatusOK, "status code mismatch")
+ })
+
+ t.Run("with invalid token", func(t *testing.T) {
+ req := testutils.MakeReq(server.URL, "GET", "/?token=someRandomToken==", "")
+ res := testutils.HTTPDo(t, req)
+
+ assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch")
+ })
+
+ t.Run("with session header", func(t *testing.T) {
+ req := testutils.MakeReq(server.URL, "GET", "/", "")
+ req.Header.Set("Authorization", "Bearer "+session.Key)
+ res := testutils.HTTPDo(t, req)
+
+ assert.Equal(t, res.StatusCode, http.StatusOK, "status code mismatch")
+ })
+
+ t.Run("with invalid session", func(t *testing.T) {
+ req := testutils.MakeReq(server.URL, "GET", "/", "")
+ req.Header.Set("Authorization", "Bearer someInvalidSessionKey=")
+ res := testutils.HTTPDo(t, req)
+
+ assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch")
+ })
+
+ t.Run("without anything", func(t *testing.T) {
+ req := testutils.MakeReq(server.URL, "GET", "/", "")
+ res := testutils.HTTPDo(t, req)
+
+ assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch")
+ })
+}
diff --git a/pkg/server/middleware/helpers.go b/pkg/server/middleware/helpers.go
index e1059db8..f43d0abb 100644
--- a/pkg/server/middleware/helpers.go
+++ b/pkg/server/middleware/helpers.go
@@ -92,7 +92,6 @@ func DoError(w http.ResponseWriter, msg string, err error, statusCode int) {
// NotSupported is the handler for the route that is no longer supported
func NotSupported(w http.ResponseWriter, r *http.Request) {
http.Error(w, "API version is not supported. Please upgrade your client.", http.StatusGone)
- return
}
// getSessionKeyFromCookie reads and returns a session key from the cookie sent by the
diff --git a/pkg/server/middleware/helpers_test.go b/pkg/server/middleware/helpers_test.go
index 1368dd70..623ec818 100644
--- a/pkg/server/middleware/helpers_test.go
+++ b/pkg/server/middleware/helpers_test.go
@@ -19,16 +19,10 @@
package middleware
import (
- "fmt"
"net/http"
- "net/http/httptest"
"testing"
- "time"
"github.com/dnote/dnote/pkg/assert"
- "github.com/dnote/dnote/pkg/server/app"
- "github.com/dnote/dnote/pkg/server/database"
- "github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
)
@@ -180,521 +174,3 @@ func TestGetCredential(t *testing.T) {
}
}
-func TestAuthMiddleware(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
-
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@test.com", "pass1234")
-
- session := database.Session{
- Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=",
- UserID: user.ID,
- ExpiresAt: time.Now().Add(time.Hour * 24),
- }
- testutils.MustExec(t, testutils.DB.Save(&session), "preparing session")
- session2 := database.Session{
- Key: "Vvgm3eBXfXGEFWERI7faiRJ3DAzJw+7DdT9J1LEyNfI=",
- UserID: user.ID,
- ExpiresAt: time.Now().Add(-time.Hour * 24),
- }
- testutils.MustExec(t, testutils.DB.Save(&session2), "preparing session")
-
- handler := func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- }
- a := &app.App{DB: testutils.DB}
- server := httptest.NewServer(Auth(a, handler, nil))
- defer server.Close()
-
- t.Run("with header", func(t *testing.T) {
- testCases := []struct {
- header string
- expectedStatus int
- }{
- {
- header: fmt.Sprintf("Bearer %s", session.Key),
- expectedStatus: http.StatusOK,
- },
- {
- header: fmt.Sprintf("Bearer %s", session2.Key),
- expectedStatus: http.StatusUnauthorized,
- },
- {
- header: fmt.Sprintf("Bearer someInvalidSessionKey="),
- expectedStatus: http.StatusUnauthorized,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.header, func(t *testing.T) {
- req := testutils.MakeReq(server.URL, "GET", "/", "")
- req.Header.Set("Authorization", tc.header)
-
- // execute
- res := testutils.HTTPDo(t, req)
-
- // test
- assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch")
- })
- }
- })
-
- t.Run("with cookie", func(t *testing.T) {
- testCases := []struct {
- cookie *http.Cookie
- expectedStatus int
- }{
- {
- cookie: &http.Cookie{
- Name: "id",
- Value: session.Key,
- HttpOnly: true,
- },
- expectedStatus: http.StatusOK,
- },
- {
- cookie: &http.Cookie{
- Name: "id",
- Value: session2.Key,
- HttpOnly: true,
- },
- expectedStatus: http.StatusUnauthorized,
- },
- {
- cookie: &http.Cookie{
- Name: "id",
- Value: "someInvalidSessionKey=",
- HttpOnly: true,
- },
- expectedStatus: http.StatusUnauthorized,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.cookie.Value, func(t *testing.T) {
- req := testutils.MakeReq(server.URL, "GET", "/", "")
- req.AddCookie(tc.cookie)
-
- // execute
- res := testutils.HTTPDo(t, req)
-
- // test
- assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch")
- })
- }
- })
-
- t.Run("without anything", func(t *testing.T) {
- req := testutils.MakeReq(server.URL, "GET", "/", "")
-
- // execute
- res := testutils.HTTPDo(t, req)
-
- // test
- assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch")
- })
-}
-
-func TestAuthMiddleware_ProOnly(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
-
- user := testutils.SetupUserData()
- testutils.MustExec(t, testutils.DB.Model(&user).Update("cloud", false), "preparing session")
- session := database.Session{
- Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=",
- UserID: user.ID,
- ExpiresAt: time.Now().Add(time.Hour * 24),
- }
- testutils.MustExec(t, testutils.DB.Save(&session), "preparing session")
-
- handler := func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- }
-
- a := &app.App{DB: testutils.DB}
- server := httptest.NewServer(Auth(a, handler, &AuthParams{
- ProOnly: true,
- }))
-
- defer server.Close()
-
- t.Run("with header", func(t *testing.T) {
- testCases := []struct {
- header string
- expectedStatus int
- }{
- {
- header: fmt.Sprintf("Bearer %s", session.Key),
- expectedStatus: http.StatusForbidden,
- },
- {
- header: fmt.Sprintf("Bearer someInvalidSessionKey="),
- expectedStatus: http.StatusUnauthorized,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.header, func(t *testing.T) {
- req := testutils.MakeReq(server.URL, "GET", "/", "")
- req.Header.Set("Authorization", tc.header)
-
- // execute
- res := testutils.HTTPDo(t, req)
-
- // test
- assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch")
- })
- }
- })
-
- t.Run("with cookie", func(t *testing.T) {
- testCases := []struct {
- cookie *http.Cookie
- expectedStatus int
- }{
- {
- cookie: &http.Cookie{
- Name: "id",
- Value: session.Key,
- HttpOnly: true,
- },
- expectedStatus: http.StatusForbidden,
- },
- {
- cookie: &http.Cookie{
- Name: "id",
- Value: "someInvalidSessionKey=",
- HttpOnly: true,
- },
- expectedStatus: http.StatusUnauthorized,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.cookie.Value, func(t *testing.T) {
- req := testutils.MakeReq(server.URL, "GET", "/", "")
- req.AddCookie(tc.cookie)
-
- // execute
- res := testutils.HTTPDo(t, req)
-
- // test
- assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch")
- })
- }
- })
-}
-
-func TestAuthMiddleware_RedirectGuestsToLogin(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
-
- handler := func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- }
-
- a := &app.App{DB: testutils.DB}
- server := httptest.NewServer(Auth(a, handler, &AuthParams{
- RedirectGuestsToLogin: true,
- }))
-
- defer server.Close()
-
- t.Run("guest", func(t *testing.T) {
- req := testutils.MakeReq(server.URL, "GET", "/", "")
-
- // execute
- res := testutils.HTTPDo(t, req)
-
- // test
- assert.Equal(t, res.StatusCode, http.StatusFound, "status code mismatch")
- assert.Equal(t, res.Header.Get("Location"), "/login?referrer=%2F", "location header mismatch")
- })
-
- t.Run("logged in user", func(t *testing.T) {
- req := testutils.MakeReq(server.URL, "GET", "/", "")
-
- user := testutils.SetupUserData()
- testutils.SetupAccountData(user, "alice@test.com", "pass1234")
-
- testutils.MustExec(t, testutils.DB.Model(&user).Update("cloud", false), "preparing session")
- session := database.Session{
- Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=",
- UserID: user.ID,
- ExpiresAt: time.Now().Add(time.Hour * 24),
- }
- testutils.MustExec(t, testutils.DB.Save(&session), "preparing session")
-
- // execute
- res := testutils.HTTPAuthDo(t, req, user)
- req.Header.Set("Authorization", session.Key)
-
- // test
- assert.Equal(t, res.StatusCode, http.StatusOK, "status code mismatch")
- assert.Equal(t, res.Header.Get("Location"), "", "location header mismatch")
- })
-
-}
-
-func TestTokenAuthMiddleWare(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
-
- user := testutils.SetupUserData()
- tok := database.Token{
- UserID: user.ID,
- Type: database.TokenTypeEmailPreference,
- Value: "xpwFnc0MdllFUePDq9DLeQ==",
- }
- testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
- session := database.Session{
- Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=",
- UserID: user.ID,
- ExpiresAt: time.Now().Add(time.Hour * 24),
- }
- testutils.MustExec(t, testutils.DB.Save(&session), "preparing session")
-
- handler := func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- }
-
- a := &app.App{DB: testutils.DB}
- server := httptest.NewServer(TokenAuth(a, handler, database.TokenTypeEmailPreference, nil))
- defer server.Close()
-
- t.Run("with token", func(t *testing.T) {
- testCases := []struct {
- token string
- expectedStatus int
- }{
- {
- token: "xpwFnc0MdllFUePDq9DLeQ==",
- expectedStatus: http.StatusOK,
- },
- {
- token: "someRandomToken==",
- expectedStatus: http.StatusUnauthorized,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.token, func(t *testing.T) {
- req := testutils.MakeReq(server.URL, "GET", fmt.Sprintf("/?token=%s", tc.token), "")
-
- // execute
- res := testutils.HTTPDo(t, req)
-
- // test
- assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch")
- })
- }
- })
-
- t.Run("with session header", func(t *testing.T) {
- testCases := []struct {
- header string
- expectedStatus int
- }{
- {
- header: fmt.Sprintf("Bearer %s", session.Key),
- expectedStatus: http.StatusOK,
- },
- {
- header: fmt.Sprintf("Bearer someInvalidSessionKey="),
- expectedStatus: http.StatusUnauthorized,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.header, func(t *testing.T) {
- req := testutils.MakeReq(server.URL, "GET", "/", "")
- req.Header.Set("Authorization", tc.header)
-
- // execute
- res := testutils.HTTPDo(t, req)
-
- // test
- assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch")
- })
- }
- })
-
- t.Run("with session cookie", func(t *testing.T) {
- testCases := []struct {
- cookie *http.Cookie
- expectedStatus int
- }{
- {
- cookie: &http.Cookie{
- Name: "id",
- Value: session.Key,
- HttpOnly: true,
- },
- expectedStatus: http.StatusOK,
- },
- {
- cookie: &http.Cookie{
- Name: "id",
- Value: "someInvalidSessionKey=",
- HttpOnly: true,
- },
- expectedStatus: http.StatusUnauthorized,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.cookie.Value, func(t *testing.T) {
- req := testutils.MakeReq(server.URL, "GET", "/", "")
- req.AddCookie(tc.cookie)
-
- // execute
- res := testutils.HTTPDo(t, req)
-
- // test
- assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch")
- })
- }
- })
-
- t.Run("without anything", func(t *testing.T) {
- req := testutils.MakeReq(server.URL, "GET", "/", "")
-
- // execute
- res := testutils.HTTPDo(t, req)
-
- // test
- assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch")
- })
-}
-
-func TestTokenAuthMiddleWare_ProOnly(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
-
- user := testutils.SetupUserData()
- testutils.MustExec(t, testutils.DB.Model(&user).Update("cloud", false), "preparing session")
- tok := database.Token{
- UserID: user.ID,
- Type: database.TokenTypeEmailPreference,
- Value: "xpwFnc0MdllFUePDq9DLeQ==",
- }
- testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token")
- session := database.Session{
- Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=",
- UserID: user.ID,
- ExpiresAt: time.Now().Add(time.Hour * 24),
- }
- testutils.MustExec(t, testutils.DB.Save(&session), "preparing session")
-
- handler := func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- }
-
- a := &app.App{DB: testutils.DB}
- server := httptest.NewServer(TokenAuth(a, handler, database.TokenTypeEmailPreference, &AuthParams{
- ProOnly: true,
- }))
-
- defer server.Close()
-
- t.Run("with token", func(t *testing.T) {
- testCases := []struct {
- token string
- expectedStatus int
- }{
- {
- token: "xpwFnc0MdllFUePDq9DLeQ==",
- expectedStatus: http.StatusForbidden,
- },
- {
- token: "someRandomToken==",
- expectedStatus: http.StatusUnauthorized,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.token, func(t *testing.T) {
- req := testutils.MakeReq(server.URL, "GET", fmt.Sprintf("/?token=%s", tc.token), "")
-
- // execute
- res := testutils.HTTPDo(t, req)
-
- // test
- assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch")
- })
- }
- })
-
- t.Run("with session header", func(t *testing.T) {
- testCases := []struct {
- header string
- expectedStatus int
- }{
- {
- header: fmt.Sprintf("Bearer %s", session.Key),
- expectedStatus: http.StatusForbidden,
- },
- {
- header: fmt.Sprintf("Bearer someInvalidSessionKey="),
- expectedStatus: http.StatusUnauthorized,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.header, func(t *testing.T) {
- req := testutils.MakeReq(server.URL, "GET", "/", "")
- req.Header.Set("Authorization", tc.header)
-
- // execute
- res := testutils.HTTPDo(t, req)
-
- // test
- assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch")
- })
- }
- })
-
- t.Run("with session cookie", func(t *testing.T) {
- testCases := []struct {
- cookie *http.Cookie
- expectedStatus int
- }{
- {
- cookie: &http.Cookie{
- Name: "id",
- Value: session.Key,
- HttpOnly: true,
- },
- expectedStatus: http.StatusForbidden,
- },
- {
- cookie: &http.Cookie{
- Name: "id",
- Value: "someInvalidSessionKey=",
- HttpOnly: true,
- },
- expectedStatus: http.StatusUnauthorized,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.cookie.Value, func(t *testing.T) {
- req := testutils.MakeReq(server.URL, "GET", "/", "")
- req.AddCookie(tc.cookie)
-
- // execute
- res := testutils.HTTPDo(t, req)
-
- // test
- assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch")
- })
- }
- })
-
- t.Run("without anything", func(t *testing.T) {
- req := testutils.MakeReq(server.URL, "GET", "/", "")
-
- // execute
- res := testutils.HTTPDo(t, req)
-
- // test
- assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch")
- })
-}
diff --git a/pkg/server/middleware/limit.go b/pkg/server/middleware/limit.go
index 64d27d3e..3b3c3987 100644
--- a/pkg/server/middleware/limit.go
+++ b/pkg/server/middleware/limit.go
@@ -80,7 +80,7 @@ func cleanupVisitors() {
mtx.Lock()
for identifier, v := range visitors {
- if time.Now().Sub(v.lastSeen) > 3*time.Minute {
+ if time.Since(v.lastSeen) > 3*time.Minute {
delete(visitors, identifier)
}
}
@@ -128,7 +128,7 @@ func Limit(next http.Handler) http.HandlerFunc {
func ApplyLimit(h http.HandlerFunc, rateLimit bool) http.Handler {
ret := h
- if rateLimit && os.Getenv("GO_ENV") != "TEST" {
+ if rateLimit && os.Getenv("APP_ENV") != "TEST" {
ret = Limit(ret)
}
diff --git a/pkg/server/middleware/main_test.go b/pkg/server/middleware/main_test.go
deleted file mode 100644
index cd96508c..00000000
--- a/pkg/server/middleware/main_test.go
+++ /dev/null
@@ -1,35 +0,0 @@
-/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
- *
- * This file is part of Dnote.
- *
- * Dnote is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * Dnote is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with Dnote. If not, see .
- */
-
-package middleware
-
-import (
- "os"
- "testing"
-
- "github.com/dnote/dnote/pkg/server/testutils"
-)
-
-func TestMain(m *testing.M) {
- testutils.InitTestDB()
-
- code := m.Run()
- testutils.ClearData(testutils.DB)
-
- os.Exit(code)
-}
diff --git a/pkg/server/middleware/middleware.go b/pkg/server/middleware/middleware.go
index bf422e94..582e8046 100644
--- a/pkg/server/middleware/middleware.go
+++ b/pkg/server/middleware/middleware.go
@@ -20,32 +20,13 @@ package middleware
import (
"net/http"
- "net/url"
"github.com/dnote/dnote/pkg/server/app"
- "github.com/gorilla/schema"
)
// Middleware is a middleware for request handlers
type Middleware func(h http.Handler, app *app.App, rateLimit bool) http.Handler
-type payload struct {
- Method string `schema:"_method"`
-}
-
-func parseValues(values url.Values, dst interface{}) error {
- dec := schema.NewDecoder()
-
- // Ignore CSRF token field
- dec.IgnoreUnknownKeys(true)
-
- if err := dec.Decode(dst, values); err != nil {
- return err
- }
-
- return nil
-}
-
// methodOverrideKey is the form key for overriding the method
var methodOverrideKey = "_method"
diff --git a/pkg/server/operations/main_test.go b/pkg/server/operations/main_test.go
deleted file mode 100644
index 19a59dbb..00000000
--- a/pkg/server/operations/main_test.go
+++ /dev/null
@@ -1,35 +0,0 @@
-/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
- *
- * This file is part of Dnote.
- *
- * Dnote is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * Dnote is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with Dnote. If not, see .
- */
-
-package operations
-
-import (
- "os"
- "testing"
-
- "github.com/dnote/dnote/pkg/server/testutils"
-)
-
-func TestMain(m *testing.M) {
- testutils.InitTestDB()
-
- code := m.Run()
- testutils.ClearData(testutils.DB)
-
- os.Exit(code)
-}
diff --git a/pkg/server/operations/notes.go b/pkg/server/operations/notes.go
index 88cac66d..ee6dd9af 100644
--- a/pkg/server/operations/notes.go
+++ b/pkg/server/operations/notes.go
@@ -19,13 +19,11 @@
package operations
import (
- "errors"
-
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/permissions"
+ "github.com/pkg/errors"
"gorm.io/gorm"
- pkgErrors "github.com/pkg/errors"
)
// GetNote retrieves a note for the given user
@@ -41,7 +39,7 @@ func GetNote(db *gorm.DB, uuid string, user *database.User) (database.Note, bool
if errors.Is(err, gorm.ErrRecordNotFound) {
return zeroNote, false, nil
} else if err != nil {
- return zeroNote, false, pkgErrors.Wrap(err, "finding note")
+ return zeroNote, false, errors.Wrap(err, "finding note")
}
if ok := permissions.ViewNote(user, note); !ok {
diff --git a/pkg/server/operations/notes_test.go b/pkg/server/operations/notes_test.go
index 020bbf75..6124b7e0 100644
--- a/pkg/server/operations/notes_test.go
+++ b/pkg/server/operations/notes_test.go
@@ -28,38 +28,41 @@ import (
)
func TestGetNote(t *testing.T) {
- user := testutils.SetupUserData()
- anotherUser := testutils.SetupUserData()
+ db := testutils.InitMemoryDB(t)
- defer testutils.ClearData(testutils.DB)
+ user := testutils.SetupUserData(db)
+ anotherUser := testutils.SetupUserData(db)
b1 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
+ testutils.MustExec(t, db.Save(&b1), "preparing b1")
privateNote := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: "privateNote content",
Deleted: false,
Public: false,
}
- testutils.MustExec(t, testutils.DB.Save(&privateNote), "preparing privateNote")
+ testutils.MustExec(t, db.Save(&privateNote), "preparing privateNote")
publicNote := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: "privateNote content",
Deleted: false,
Public: true,
}
- testutils.MustExec(t, testutils.DB.Save(&publicNote), "preparing privateNote")
+ testutils.MustExec(t, db.Save(&publicNote), "preparing privateNote")
var privateNoteRecord, publicNoteRecord database.Note
- testutils.MustExec(t, testutils.DB.Where("uuid = ?", privateNote.UUID).Preload("Book").Preload("User").First(&privateNoteRecord), "finding privateNote")
- testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).Preload("Book").Preload("User").First(&publicNoteRecord), "finding publicNote")
+ testutils.MustExec(t, db.Where("uuid = ?", privateNote.UUID).Preload("Book").Preload("User").First(&privateNoteRecord), "finding privateNote")
+ testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).Preload("Book").Preload("User").First(&publicNoteRecord), "finding publicNote")
testCases := []struct {
name string
@@ -107,7 +110,7 @@ func TestGetNote(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- note, ok, err := GetNote(testutils.DB, tc.note.UUID, &tc.user)
+ note, ok, err := GetNote(db, tc.note.UUID, &tc.user)
if err != nil {
t.Fatal(errors.Wrap(err, "executing"))
}
@@ -119,29 +122,29 @@ func TestGetNote(t *testing.T) {
}
func TestGetNote_nonexistent(t *testing.T) {
- user := testutils.SetupUserData()
+ db := testutils.InitMemoryDB(t)
- defer testutils.ClearData(testutils.DB)
+ user := testutils.SetupUserData(db)
b1 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
+ testutils.MustExec(t, db.Save(&b1), "preparing b1")
- n1UUID := "4fd19336-671e-4ff3-8f22-662b80e22edc"
n1 := database.Note{
- UUID: n1UUID,
+ UUID: "4fd19336-671e-4ff3-8f22-662b80e22edc",
UserID: user.ID,
BookUUID: b1.UUID,
Body: "n1 content",
Deleted: false,
Public: false,
}
- testutils.MustExec(t, testutils.DB.Save(&n1), "preparing n1")
+ testutils.MustExec(t, db.Save(&n1), "preparing n1")
nonexistentUUID := "4fd19336-671e-4ff3-8f22-662b80e22edd"
- note, ok, err := GetNote(testutils.DB, nonexistentUUID, &user)
+ note, ok, err := GetNote(db, nonexistentUUID, &user)
if err != nil {
t.Fatal(errors.Wrap(err, "executing"))
}
diff --git a/pkg/server/permissions/permissions_test.go b/pkg/server/permissions/permissions_test.go
index 607fb2a1..4054b66f 100644
--- a/pkg/server/permissions/permissions_test.go
+++ b/pkg/server/permissions/permissions_test.go
@@ -19,7 +19,6 @@
package permissions
import (
- "os"
"testing"
"github.com/dnote/dnote/pkg/assert"
@@ -27,44 +26,38 @@ import (
"github.com/dnote/dnote/pkg/server/testutils"
)
-func TestMain(m *testing.M) {
- testutils.InitTestDB()
-
- code := m.Run()
- testutils.ClearData(testutils.DB)
-
- os.Exit(code)
-}
-
func TestViewNote(t *testing.T) {
- user := testutils.SetupUserData()
- anotherUser := testutils.SetupUserData()
+ db := testutils.InitMemoryDB(t)
- defer testutils.ClearData(testutils.DB)
+ user := testutils.SetupUserData(db)
+ anotherUser := testutils.SetupUserData(db)
b1 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
+ testutils.MustExec(t, db.Save(&b1), "preparing b1")
privateNote := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: "privateNote content",
Deleted: false,
Public: false,
}
- testutils.MustExec(t, testutils.DB.Save(&privateNote), "preparing privateNote")
+ testutils.MustExec(t, db.Save(&privateNote), "preparing privateNote")
publicNote := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: "privateNote content",
Deleted: false,
Public: true,
}
- testutils.MustExec(t, testutils.DB.Save(&publicNote), "preparing privateNote")
+ testutils.MustExec(t, db.Save(&publicNote), "preparing privateNote")
t.Run("owner accessing private note", func(t *testing.T) {
result := ViewNote(&user, privateNote)
diff --git a/pkg/server/presenters/book_test.go b/pkg/server/presenters/book_test.go
new file mode 100644
index 00000000..98155769
--- /dev/null
+++ b/pkg/server/presenters/book_test.go
@@ -0,0 +1,217 @@
+/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
+ *
+ * This file is part of Dnote.
+ *
+ * Dnote is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * Dnote is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with Dnote. If not, see .
+ */
+
+package presenters
+
+import (
+ "testing"
+ "time"
+
+ "github.com/dnote/dnote/pkg/assert"
+ "github.com/dnote/dnote/pkg/server/database"
+)
+
+func TestPresentBook(t *testing.T) {
+ createdAt := time.Date(2025, 1, 15, 10, 30, 45, 123456789, time.UTC)
+ updatedAt := time.Date(2025, 2, 20, 14, 45, 30, 987654321, time.UTC)
+
+ testCases := []struct {
+ name string
+ input database.Book
+ expected Book
+ }{
+ {
+ name: "basic book",
+ input: database.Book{
+ Model: database.Model{
+ ID: 1,
+ CreatedAt: createdAt,
+ UpdatedAt: updatedAt,
+ },
+ UUID: "a1b2c3d4-e5f6-4789-a012-3456789abcde",
+ UserID: 42,
+ Label: "JavaScript",
+ USN: 100,
+ },
+ expected: Book{
+ UUID: "a1b2c3d4-e5f6-4789-a012-3456789abcde",
+ USN: 100,
+ CreatedAt: FormatTS(createdAt),
+ UpdatedAt: FormatTS(updatedAt),
+ Label: "JavaScript",
+ },
+ },
+ {
+ name: "book with special characters in label",
+ input: database.Book{
+ Model: database.Model{
+ ID: 2,
+ CreatedAt: createdAt,
+ UpdatedAt: updatedAt,
+ },
+ UUID: "f1e2d3c4-b5a6-4987-b654-321fedcba098",
+ UserID: 99,
+ Label: "C++",
+ USN: 200,
+ },
+ expected: Book{
+ UUID: "f1e2d3c4-b5a6-4987-b654-321fedcba098",
+ USN: 200,
+ CreatedAt: FormatTS(createdAt),
+ UpdatedAt: FormatTS(updatedAt),
+ Label: "C++",
+ },
+ },
+ {
+ name: "book with empty label",
+ input: database.Book{
+ Model: database.Model{
+ ID: 3,
+ CreatedAt: createdAt,
+ UpdatedAt: updatedAt,
+ },
+ UUID: "12345678-90ab-4cde-8901-234567890abc",
+ UserID: 1,
+ Label: "",
+ USN: 0,
+ },
+ expected: Book{
+ UUID: "12345678-90ab-4cde-8901-234567890abc",
+ USN: 0,
+ CreatedAt: FormatTS(createdAt),
+ UpdatedAt: FormatTS(updatedAt),
+ Label: "",
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ got := PresentBook(tc.input)
+
+ assert.Equal(t, got.UUID, tc.expected.UUID, "UUID mismatch")
+ assert.Equal(t, got.USN, tc.expected.USN, "USN mismatch")
+ assert.Equal(t, got.Label, tc.expected.Label, "Label mismatch")
+ assert.Equal(t, got.CreatedAt, tc.expected.CreatedAt, "CreatedAt mismatch")
+ assert.Equal(t, got.UpdatedAt, tc.expected.UpdatedAt, "UpdatedAt mismatch")
+ })
+ }
+}
+
+func TestPresentBooks(t *testing.T) {
+ createdAt1 := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
+ updatedAt1 := time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)
+ createdAt2 := time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC)
+ updatedAt2 := time.Date(2025, 2, 2, 0, 0, 0, 0, time.UTC)
+
+ testCases := []struct {
+ name string
+ input []database.Book
+ expected []Book
+ }{
+ {
+ name: "empty slice",
+ input: []database.Book{},
+ expected: []Book{},
+ },
+ {
+ name: "single book",
+ input: []database.Book{
+ {
+ Model: database.Model{
+ ID: 1,
+ CreatedAt: createdAt1,
+ UpdatedAt: updatedAt1,
+ },
+ UUID: "9a8b7c6d-5e4f-4321-9876-543210fedcba",
+ UserID: 1,
+ Label: "Go",
+ USN: 10,
+ },
+ },
+ expected: []Book{
+ {
+ UUID: "9a8b7c6d-5e4f-4321-9876-543210fedcba",
+ USN: 10,
+ CreatedAt: FormatTS(createdAt1),
+ UpdatedAt: FormatTS(updatedAt1),
+ Label: "Go",
+ },
+ },
+ },
+ {
+ name: "multiple books",
+ input: []database.Book{
+ {
+ Model: database.Model{
+ ID: 1,
+ CreatedAt: createdAt1,
+ UpdatedAt: updatedAt1,
+ },
+ UUID: "9a8b7c6d-5e4f-4321-9876-543210fedcba",
+ UserID: 1,
+ Label: "Go",
+ USN: 10,
+ },
+ {
+ Model: database.Model{
+ ID: 2,
+ CreatedAt: createdAt2,
+ UpdatedAt: updatedAt2,
+ },
+ UUID: "abcdef01-2345-4678-9abc-def012345678",
+ UserID: 1,
+ Label: "Python",
+ USN: 20,
+ },
+ },
+ expected: []Book{
+ {
+ UUID: "9a8b7c6d-5e4f-4321-9876-543210fedcba",
+ USN: 10,
+ CreatedAt: FormatTS(createdAt1),
+ UpdatedAt: FormatTS(updatedAt1),
+ Label: "Go",
+ },
+ {
+ UUID: "abcdef01-2345-4678-9abc-def012345678",
+ USN: 20,
+ CreatedAt: FormatTS(createdAt2),
+ UpdatedAt: FormatTS(updatedAt2),
+ Label: "Python",
+ },
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ got := PresentBooks(tc.input)
+
+ assert.Equal(t, len(got), len(tc.expected), "Length mismatch")
+
+ for i := range got {
+ assert.Equal(t, got[i].UUID, tc.expected[i].UUID, "UUID mismatch")
+ assert.Equal(t, got[i].USN, tc.expected[i].USN, "USN mismatch")
+ assert.Equal(t, got[i].Label, tc.expected[i].Label, "Label mismatch")
+ assert.Equal(t, got[i].CreatedAt, tc.expected[i].CreatedAt, "CreatedAt mismatch")
+ assert.Equal(t, got[i].UpdatedAt, tc.expected[i].UpdatedAt, "UpdatedAt mismatch")
+ }
+ })
+ }
+}
diff --git a/pkg/server/presenters/email_preference.go b/pkg/server/presenters/email_preference.go
deleted file mode 100644
index acf52eed..00000000
--- a/pkg/server/presenters/email_preference.go
+++ /dev/null
@@ -1,45 +0,0 @@
-/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
- *
- * This file is part of Dnote.
- *
- * Dnote is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * Dnote is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with Dnote. If not, see .
- */
-
-package presenters
-
-import (
- "time"
-
- "github.com/dnote/dnote/pkg/server/database"
-)
-
-// EmailPreference is a presented email digest
-type EmailPreference struct {
- InactiveReminder bool `json:"inactive_reminder"`
- ProductUpdate bool `json:"product_update"`
- CreatedAt time.Time `json:"created_at"`
- UpdatedAt time.Time `json:"updated_at"`
-}
-
-// PresentEmailPreference presents a digest
-func PresentEmailPreference(p database.EmailPreference) EmailPreference {
- ret := EmailPreference{
- InactiveReminder: p.InactiveReminder,
- ProductUpdate: p.ProductUpdate,
- CreatedAt: FormatTS(p.CreatedAt),
- UpdatedAt: FormatTS(p.UpdatedAt),
- }
-
- return ret
-}
diff --git a/pkg/server/tmpl/main_test.go b/pkg/server/presenters/helpers_test.go
similarity index 71%
rename from pkg/server/tmpl/main_test.go
rename to pkg/server/presenters/helpers_test.go
index ceabf5cc..c48c90ea 100644
--- a/pkg/server/tmpl/main_test.go
+++ b/pkg/server/presenters/helpers_test.go
@@ -16,20 +16,20 @@
* along with Dnote. If not, see .
*/
-package tmpl
+package presenters
import (
- "os"
"testing"
+ "time"
- "github.com/dnote/dnote/pkg/server/testutils"
+ "github.com/dnote/dnote/pkg/assert"
)
-func TestMain(m *testing.M) {
- testutils.InitTestDB()
+func TestFormatTS(t *testing.T) {
+ input := time.Date(2025, 1, 15, 10, 30, 45, 123456789, time.UTC)
+ expected := time.Date(2025, 1, 15, 10, 30, 45, 123457000, time.UTC)
- code := m.Run()
- testutils.ClearData(testutils.DB)
+ got := FormatTS(input)
- os.Exit(code)
+ assert.Equal(t, got, expected, "FormatTS mismatch")
}
diff --git a/pkg/server/presenters/note_test.go b/pkg/server/presenters/note_test.go
new file mode 100644
index 00000000..822c5cea
--- /dev/null
+++ b/pkg/server/presenters/note_test.go
@@ -0,0 +1,127 @@
+/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
+ *
+ * This file is part of Dnote.
+ *
+ * Dnote is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * Dnote is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with Dnote. If not, see .
+ */
+
+package presenters
+
+import (
+ "testing"
+ "time"
+
+ "github.com/dnote/dnote/pkg/assert"
+ "github.com/dnote/dnote/pkg/server/database"
+)
+
+func TestPresentNote(t *testing.T) {
+ createdAt := time.Date(2025, 1, 15, 10, 30, 45, 123456789, time.UTC)
+ updatedAt := time.Date(2025, 2, 20, 14, 45, 30, 987654321, time.UTC)
+
+ input := database.Note{
+ Model: database.Model{
+ ID: 1,
+ CreatedAt: createdAt,
+ UpdatedAt: updatedAt,
+ },
+ UUID: "a1b2c3d4-e5f6-4789-a012-3456789abcde",
+ UserID: 42,
+ BookUUID: "f1e2d3c4-b5a6-4987-b654-321fedcba098",
+ Body: "Test note content",
+ AddedOn: 1234567890,
+ Public: true,
+ USN: 100,
+ Book: database.Book{
+ UUID: "f1e2d3c4-b5a6-4987-b654-321fedcba098",
+ Label: "JavaScript",
+ },
+ User: database.User{
+ UUID: "9a8b7c6d-5e4f-4321-9876-543210fedcba",
+ },
+ }
+
+ got := PresentNote(input)
+
+ assert.Equal(t, got.UUID, "a1b2c3d4-e5f6-4789-a012-3456789abcde", "UUID mismatch")
+ assert.Equal(t, got.Body, "Test note content", "Body mismatch")
+ assert.Equal(t, got.AddedOn, int64(1234567890), "AddedOn mismatch")
+ assert.Equal(t, got.Public, true, "Public mismatch")
+ assert.Equal(t, got.USN, 100, "USN mismatch")
+ assert.Equal(t, got.CreatedAt, FormatTS(createdAt), "CreatedAt mismatch")
+ assert.Equal(t, got.UpdatedAt, FormatTS(updatedAt), "UpdatedAt mismatch")
+ assert.Equal(t, got.Book.UUID, "f1e2d3c4-b5a6-4987-b654-321fedcba098", "Book UUID mismatch")
+ assert.Equal(t, got.Book.Label, "JavaScript", "Book Label mismatch")
+ assert.Equal(t, got.User.UUID, "9a8b7c6d-5e4f-4321-9876-543210fedcba", "User UUID mismatch")
+}
+
+func TestPresentNotes(t *testing.T) {
+ createdAt1 := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
+ updatedAt1 := time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)
+ createdAt2 := time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC)
+ updatedAt2 := time.Date(2025, 2, 2, 0, 0, 0, 0, time.UTC)
+
+ input := []database.Note{
+ {
+ Model: database.Model{
+ ID: 1,
+ CreatedAt: createdAt1,
+ UpdatedAt: updatedAt1,
+ },
+ UUID: "a1b2c3d4-e5f6-4789-a012-3456789abcde",
+ UserID: 1,
+ BookUUID: "f1e2d3c4-b5a6-4987-b654-321fedcba098",
+ Body: "First note",
+ AddedOn: 1000000000,
+ Public: false,
+ USN: 10,
+ Book: database.Book{
+ UUID: "f1e2d3c4-b5a6-4987-b654-321fedcba098",
+ Label: "Go",
+ },
+ User: database.User{
+ UUID: "9a8b7c6d-5e4f-4321-9876-543210fedcba",
+ },
+ },
+ {
+ Model: database.Model{
+ ID: 2,
+ CreatedAt: createdAt2,
+ UpdatedAt: updatedAt2,
+ },
+ UUID: "12345678-90ab-4cde-8901-234567890abc",
+ UserID: 1,
+ BookUUID: "abcdef01-2345-4678-9abc-def012345678",
+ Body: "Second note",
+ AddedOn: 2000000000,
+ Public: true,
+ USN: 20,
+ Book: database.Book{
+ UUID: "abcdef01-2345-4678-9abc-def012345678",
+ Label: "Python",
+ },
+ User: database.User{
+ UUID: "9a8b7c6d-5e4f-4321-9876-543210fedcba",
+ },
+ },
+ }
+
+ got := PresentNotes(input)
+
+ assert.Equal(t, len(got), 2, "Length mismatch")
+ assert.Equal(t, got[0].UUID, "a1b2c3d4-e5f6-4789-a012-3456789abcde", "Note 0 UUID mismatch")
+ assert.Equal(t, got[0].Body, "First note", "Note 0 Body mismatch")
+ assert.Equal(t, got[1].UUID, "12345678-90ab-4cde-8901-234567890abc", "Note 1 UUID mismatch")
+ assert.Equal(t, got[1].Body, "Second note", "Note 1 Body mismatch")
+}
diff --git a/pkg/server/session/session.go b/pkg/server/session/session.go
index 5ad92fec..8c55549c 100644
--- a/pkg/server/session/session.go
+++ b/pkg/server/session/session.go
@@ -27,14 +27,12 @@ type Session struct {
UUID string `json:"uuid"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
- Pro bool `json:"pro"`
}
// New returns a new session for the given user
func New(user database.User, account database.Account) Session {
return Session{
UUID: user.UUID,
- Pro: user.Cloud,
Email: account.Email.String,
EmailVerified: account.EmailVerified,
}
diff --git a/pkg/server/session/session_test.go b/pkg/server/session/session_test.go
index 967053e1..107dacfe 100644
--- a/pkg/server/session/session_test.go
+++ b/pkg/server/session/session_test.go
@@ -27,36 +27,32 @@ import (
)
func TestNew(t *testing.T) {
- u1 := database.User{UUID: "0f5f0054-d23f-4be1-b5fb-57673109e9cb", Cloud: true}
+ u1 := database.User{UUID: "0f5f0054-d23f-4be1-b5fb-57673109e9cb"}
a1 := database.Account{Email: database.ToNullString("alice@example.com"), EmailVerified: false}
- u2 := database.User{UUID: "718a1041-bbe6-496e-bbe4-ea7e572c295e", Cloud: false}
+ u2 := database.User{UUID: "718a1041-bbe6-496e-bbe4-ea7e572c295e"}
a2 := database.Account{Email: database.ToNullString("bob@example.com"), EmailVerified: false}
testCases := []struct {
- user database.User
- account database.Account
- expectedPro bool
+ user database.User
+ account database.Account
}{
{
- user: u1,
- account: a1,
- expectedPro: true,
+ user: u1,
+ account: a1,
},
{
- user: u2,
- account: a2,
- expectedPro: false,
+ user: u2,
+ account: a2,
},
}
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("user pro %t", tc.expectedPro), func(t *testing.T) {
+ for idx, tc := range testCases {
+ t.Run(fmt.Sprintf("user %d", idx), func(t *testing.T) {
// Execute
got := New(tc.user, tc.account)
expected := Session{
UUID: tc.user.UUID,
- Pro: tc.expectedPro,
Email: tc.account.Email.String,
EmailVerified: tc.account.EmailVerified,
}
diff --git a/pkg/server/testutils/main.go b/pkg/server/testutils/main.go
index 33a4a00d..08bc07de 100644
--- a/pkg/server/testutils/main.go
+++ b/pkg/server/testutils/main.go
@@ -27,36 +27,44 @@ import (
"net/http"
"net/url"
"reflect"
- // "strconv"
"strings"
"sync"
"testing"
"time"
- "github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/database"
- "gorm.io/gorm"
+ "github.com/dnote/dnote/pkg/server/helpers"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
+ "gorm.io/driver/sqlite"
+ "gorm.io/gorm"
)
-func init() {
- rand.Seed(time.Now().UnixNano())
+// InitDB opens a database at the given path and initializes the schema
+func InitDB(dbPath string) *gorm.DB {
+ db := database.Open(dbPath)
+ database.InitSchema(db)
+ database.Migrate(db)
+ return db
}
-// DB is the database connection to a test database
-var DB *gorm.DB
-
-// InitTestDB establishes connection pool with the test database specified by
-// the environment variable configuration and initalizes a new schema
-func InitTestDB() {
- c := config.Load()
- fmt.Println(c.DB.GetConnectionStr())
- db := database.Open(c)
+// InitMemoryDB creates an in-memory SQLite database with the schema initialized
+func InitMemoryDB(t *testing.T) *gorm.DB {
+ // Use file-based in-memory database with unique UUID per test to avoid sharing
+ uuid, err := helpers.GenUUID()
+ if err != nil {
+ t.Fatalf("failed to generate UUID for test database: %v", err)
+ }
+ dbName := fmt.Sprintf("file:%s?mode=memory&cache=shared", uuid)
+ db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{})
+ if err != nil {
+ t.Fatalf("failed to open in-memory database: %v", err)
+ }
database.InitSchema(db)
+ database.Migrate(db)
- DB = db
+ return db
}
// ClearData deletes all records from the database
@@ -68,15 +76,9 @@ func ClearData(db *gorm.DB) {
if err := db.Where("1 = 1").Delete(&database.Book{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear books"))
}
- if err := db.Where("1 = 1").Delete(&database.Notification{}).Error; err != nil {
- panic(errors.Wrap(err, "Failed to clear notifications"))
- }
if err := db.Where("1 = 1").Delete(&database.Token{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear tokens"))
}
- if err := db.Where("1 = 1").Delete(&database.EmailPreference{}).Error; err != nil {
- panic(errors.Wrap(err, "Failed to clear email preferences"))
- }
if err := db.Where("1 = 1").Delete(&database.Session{}).Error; err != nil {
panic(errors.Wrap(err, "Failed to clear sessions"))
}
@@ -88,13 +90,27 @@ func ClearData(db *gorm.DB) {
}
}
+// MustUUID generates a UUID and fails the test on error
+func MustUUID(t *testing.T) string {
+ uuid, err := helpers.GenUUID()
+ if err != nil {
+ t.Fatal(errors.Wrap(err, "Failed to generate UUID"))
+ }
+ return uuid
+}
+
// SetupUserData creates and returns a new user for testing purposes
-func SetupUserData() database.User {
- user := database.User{
- Cloud: true,
+func SetupUserData(db *gorm.DB) database.User {
+ uuid, err := helpers.GenUUID()
+ if err != nil {
+ panic(errors.Wrap(err, "Failed to generate UUID"))
}
- if err := DB.Save(&user).Error; err != nil {
+ user := database.User{
+ UUID: uuid,
+ }
+
+ if err := db.Save(&user).Error; err != nil {
panic(errors.Wrap(err, "Failed to prepare user"))
}
@@ -102,7 +118,7 @@ func SetupUserData() database.User {
}
// SetupAccountData creates and returns a new account for the user
-func SetupAccountData(user database.User, email, password string) database.Account {
+func SetupAccountData(db *gorm.DB, user database.User, email, password string) database.Account {
account := database.Account{
UserID: user.ID,
}
@@ -116,7 +132,7 @@ func SetupAccountData(user database.User, email, password string) database.Accou
}
account.Password = database.ToNullString(string(hashedPassword))
- if err := DB.Save(&account).Error; err != nil {
+ if err := db.Save(&account).Error; err != nil {
panic(errors.Wrap(err, "Failed to prepare account"))
}
@@ -124,33 +140,19 @@ func SetupAccountData(user database.User, email, password string) database.Accou
}
// SetupSession creates and returns a new user session
-func SetupSession(t *testing.T, user database.User) database.Session {
+func SetupSession(db *gorm.DB, user database.User) database.Session {
session := database.Session{
Key: "Vvgm3eBXfXGEFWERI7faiRJ3DAzJw+7DdT9J1LEyNfI=",
UserID: user.ID,
ExpiresAt: time.Now().Add(time.Hour * 24),
}
- if err := DB.Save(&session).Error; err != nil {
- t.Fatal(errors.Wrap(err, "Failed to prepare user"))
+ if err := db.Save(&session).Error; err != nil {
+ panic(errors.Wrap(err, "Failed to prepare user"))
}
return session
}
-// SetupEmailPreferenceData creates and returns a new email frequency for a user
-func SetupEmailPreferenceData(user database.User, inactiveReminder bool) database.EmailPreference {
- frequency := database.EmailPreference{
- UserID: user.ID,
- InactiveReminder: inactiveReminder,
- }
-
- if err := DB.Save(&frequency).Error; err != nil {
- panic(errors.Wrap(err, "Failed to prepare email frequency"))
- }
-
- return frequency
-}
-
// HTTPDo makes an HTTP request and returns a response
func HTTPDo(t *testing.T, req *http.Request) *http.Response {
hc := http.Client{
@@ -170,8 +172,8 @@ func HTTPDo(t *testing.T, req *http.Request) *http.Response {
return res
}
-// SetReqAuthHeader sets the authorization header in the given request for the given user
-func SetReqAuthHeader(t *testing.T, req *http.Request, user database.User) {
+// SetReqAuthHeader sets the authorization header in the given request for the given user with a specific DB
+func SetReqAuthHeader(t *testing.T, db *gorm.DB, req *http.Request, user database.User) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
t.Fatal(errors.Wrap(err, "reading random bits"))
@@ -182,19 +184,18 @@ func SetReqAuthHeader(t *testing.T, req *http.Request, user database.User) {
UserID: user.ID,
ExpiresAt: time.Now().Add(time.Hour * 10 * 24),
}
- if err := DB.Save(&session).Error; err != nil {
+ if err := db.Save(&session).Error; err != nil {
t.Fatal(errors.Wrap(err, "Failed to prepare user"))
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", session.Key))
}
-// HTTPAuthDo makes an HTTP request with an appropriate authorization header for a user
-func HTTPAuthDo(t *testing.T, req *http.Request, user database.User) *http.Response {
- SetReqAuthHeader(t, req, user)
+// HTTPAuthDo makes an HTTP request with an appropriate authorization header for a user with a specific DB
+func HTTPAuthDo(t *testing.T, db *gorm.DB, req *http.Request, user database.User) *http.Response {
+ SetReqAuthHeader(t, db, req, user)
return HTTPDo(t, req)
-
}
// MakeReq makes an HTTP request and returns a response
diff --git a/pkg/server/tmpl/app_test.go b/pkg/server/tmpl/app_test.go
index f7042574..fba9bed4 100644
--- a/pkg/server/tmpl/app_test.go
+++ b/pkg/server/tmpl/app_test.go
@@ -31,7 +31,9 @@ import (
func TestAppShellExecute(t *testing.T) {
t.Run("home", func(t *testing.T) {
- a, err := NewAppShell(testutils.DB, []byte("
{{ .Title }}{{ .MetaTags }}"))
+ db := testutils.InitMemoryDB(t)
+
+ a, err := NewAppShell(db, []byte("{{ .Title }}{{ .MetaTags }}"))
if err != nil {
t.Fatal(errors.Wrap(err, "preparing app shell"))
}
@@ -50,23 +52,25 @@ func TestAppShellExecute(t *testing.T) {
})
t.Run("note", func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
- user := testutils.SetupUserData()
+ user := testutils.SetupUserData(db)
b1 := database.Book{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
}
- testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
+ testutils.MustExec(t, db.Save(&b1), "preparing b1")
n1 := database.Note{
+ UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Public: true,
Body: "n1 content",
}
- testutils.MustExec(t, testutils.DB.Save(&n1), "preparing note")
+ testutils.MustExec(t, db.Save(&n1), "preparing note")
- a, err := NewAppShell(testutils.DB, []byte("{{ .MetaTags }}"))
+ a, err := NewAppShell(db, []byte("{{ .MetaTags }}"))
if err != nil {
t.Fatal(errors.Wrap(err, "preparing app shell"))
}
diff --git a/pkg/server/tmpl/data_test.go b/pkg/server/tmpl/data_test.go
index 8c60e3d3..c072d12e 100644
--- a/pkg/server/tmpl/data_test.go
+++ b/pkg/server/tmpl/data_test.go
@@ -42,7 +42,8 @@ func TestNotePageGetData(t *testing.T) {
// Set time.Local to UTC for deterministic test
time.Local = time.UTC
- a, err := NewAppShell(testutils.DB, nil)
+ db := testutils.InitMemoryDB(t)
+ a, err := NewAppShell(db, nil)
if err != nil {
t.Fatal(errors.Wrap(err, "preparing app shell"))
}
diff --git a/pkg/server/token/main_test.go b/pkg/server/token/main_test.go
deleted file mode 100644
index 84e4f0b2..00000000
--- a/pkg/server/token/main_test.go
+++ /dev/null
@@ -1,35 +0,0 @@
-/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
- *
- * This file is part of Dnote.
- *
- * Dnote is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * Dnote is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with Dnote. If not, see .
- */
-
-package token
-
-import (
- "os"
- "testing"
-
- "github.com/dnote/dnote/pkg/server/testutils"
-)
-
-func TestMain(m *testing.M) {
- testutils.InitTestDB()
-
- code := m.Run()
- testutils.ClearData(testutils.DB)
-
- os.Exit(code)
-}
diff --git a/pkg/server/token/token_test.go b/pkg/server/token/token_test.go
index 371ab71c..922cc93d 100644
--- a/pkg/server/token/token_test.go
+++ b/pkg/server/token/token_test.go
@@ -33,30 +33,30 @@ func TestCreate(t *testing.T) {
kind string
}{
{
- kind: database.TokenTypeEmailPreference,
+ kind: database.TokenTypeEmailVerification,
},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("token type %s", tc.kind), func(t *testing.T) {
- defer testutils.ClearData(testutils.DB)
+ db := testutils.InitMemoryDB(t)
// Set up
- u := testutils.SetupUserData()
+ u := testutils.SetupUserData(db)
// Execute
- tok, err := Create(testutils.DB, u.ID, tc.kind)
+ tok, err := Create(db, u.ID, tc.kind)
if err != nil {
t.Fatal(errors.Wrap(err, "performing"))
}
// Test
var count int64
- testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&count), "counting token")
+ testutils.MustExec(t, db.Model(&database.Token{}).Count(&count), "counting token")
assert.Equalf(t, count, int64(1), "error mismatch")
var tokenRecord database.Token
- testutils.MustExec(t, testutils.DB.First(&tokenRecord), "finding token")
+ testutils.MustExec(t, db.First(&tokenRecord), "finding token")
assert.Equalf(t, tokenRecord.UserID, tok.UserID, "UserID mismatch")
assert.Equalf(t, tokenRecord.Value, tok.Value, "Value mismatch")
assert.Equalf(t, tokenRecord.Type, tok.Type, "Type mismatch")
diff --git a/pkg/server/views/helpers.go b/pkg/server/views/helpers.go
index 867033af..ac4546f7 100644
--- a/pkg/server/views/helpers.go
+++ b/pkg/server/views/helpers.go
@@ -50,7 +50,7 @@ func initHelpers(c Config, a *app.App) template.FuncMap {
"defaultValue": ctx.defaultValue,
"add": ctx.add,
"assetBaseURL": func() string {
- return a.Config.AssetBaseURL
+ return a.AssetBaseURL
},
}
diff --git a/pkg/server/views/templates/users/settings_about.gohtml b/pkg/server/views/templates/users/settings_about.gohtml
index bba9a8f0..3252b9de 100644
--- a/pkg/server/views/templates/users/settings_about.gohtml
+++ b/pkg/server/views/templates/users/settings_about.gohtml
@@ -27,27 +27,6 @@
- {{if ne .Standalone "true"}}
-
- {{else}}
-
- {{end}}
diff --git a/pkg/server/views/view.go b/pkg/server/views/view.go
index 75a9aa92..2b0e57a9 100644
--- a/pkg/server/views/view.go
+++ b/pkg/server/views/view.go
@@ -119,9 +119,6 @@ func (v *View) Render(w http.ResponseWriter, r *http.Request, data *Data, status
vd.Yield["EmailVerified"] = vd.Account.EmailVerified
vd.Yield["EmailVerified"] = vd.Account.EmailVerified
}
- if vd.User != nil {
- vd.Yield["Cloud"] = vd.User.Cloud
- }
vd.Yield["CurrentPath"] = r.URL.Path
vd.Yield["Standalone"] = buildinfo.Standalone
diff --git a/scripts/server/dev.sh b/scripts/server/dev.sh
index 11f1e207..5eb93328 100755
--- a/scripts/server/dev.sh
+++ b/scripts/server/dev.sh
@@ -23,7 +23,7 @@ cp "$basePath"/pkg/server/assets/static/* "$basePath/pkg/server/static"
# run server
moduleName="github.com/dnote/dnote"
ldflags="-X '$moduleName/pkg/server/buildinfo.CSSFiles=main.css' -X '$moduleName/pkg/server/buildinfo.JSFiles=main.js' -X '$moduleName/pkg/server/buildinfo.Version=dev' -X '$moduleName/pkg/server/buildinfo.Standalone=true'"
-task="go run -ldflags \"$ldflags\" main.go start -port 3000"
+task="go run -ldflags \"$ldflags\" --tags fts5 main.go start -port 3000"
(
cd "$basePath/pkg/watcher" && \
diff --git a/scripts/server/test.sh b/scripts/server/test.sh
index ed1d0a41..10207010 100755
--- a/scripts/server/test.sh
+++ b/scripts/server/test.sh
@@ -8,9 +8,9 @@ pushd "$dir/../../pkg/server"
function run_test {
if [ -z "$1" ]; then
- go test ./... -cover -p 1
+ go test -tags "fts5" ./... -cover
else
- go test -run "$1" -cover -p 1
+ go test -tags "fts5" -run "$1" -cover
fi
}