Use SQLite on the server (#681)

* Use SQLite on server

* Remove pro

* Simplify

* Use flag

* Automate release
This commit is contained in:
Sung 2025-10-05 17:02:30 -07:00 committed by GitHub
commit 61162e2add
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
127 changed files with 3330 additions and 3514 deletions

View file

@ -11,25 +11,6 @@ jobs:
build:
runs-on: ubuntu-22.04
services:
postgres:
image: postgres:14
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: dnote_test
POSTGRES_PORT: 5432
# Wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
# Expose port to the host
ports:
- 5432:5432
steps:
- uses: actions/checkout@v5
- uses: actions/setup-go@v6

91
.github/workflows/release-server.yml vendored Normal file
View file

@ -0,0 +1,91 @@
name: Release Server
on:
push:
tags:
- 'server-v*'
jobs:
release:
runs-on: ubuntu-22.04
permissions:
contents: write
steps:
- uses: actions/checkout@v5
- uses: actions/setup-go@v6
with:
go-version: '>=1.25.0'
- uses: actions/setup-node@v4
with:
node-version: '20'
- name: Extract version from tag
id: version
run: |
TAG=${GITHUB_REF#refs/tags/server-v}
echo "version=$TAG" >> $GITHUB_OUTPUT
echo "Releasing version: $TAG"
- name: Install dependencies
run: make install
- name: Run tests
run: make test
- name: Build server
run: make version=${{ steps.version.outputs.version }} build-server
- name: Prepare Docker build context
run: |
VERSION="${{ steps.version.outputs.version }}"
cp build/server/dnote_server_${VERSION}_linux_amd64.tar.gz host/docker/
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push Docker image
uses: docker/build-push-action@v6
with:
context: ./host/docker
push: true
tags: |
dnote/dnote:${{ steps.version.outputs.version }}
dnote/dnote:latest
build-args: |
tarballName=dnote_server_${{ steps.version.outputs.version }}_linux_amd64.tar.gz
- name: Create GitHub release
env:
GH_TOKEN: ${{ github.token }}
run: |
VERSION="${{ steps.version.outputs.version }}"
TAG="server-v${VERSION}"
# Determine if prerelease (version not matching major.minor.patch)
FLAGS=""
if [[ ! "$VERSION" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
FLAGS="--prerelease"
fi
gh release create "$TAG" \
build/server/*.tar.gz \
build/server/*_checksums.txt \
$FLAGS \
--title="$TAG" \
--notes="Please see the [CHANGELOG](https://github.com/dnote/dnote/blob/master/CHANGELOG.md)" \
--draft
- name: Push to Docker Hub
env:
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
run: |
VERSION="${{ steps.version.outputs.version }}"
echo "$DOCKER_TOKEN" | docker login -u "$DOCKER_USERNAME" --password-stdin
docker push dnote/dnote:${VERSION}
docker push dnote/dnote:latest

2
.gitignore vendored
View file

@ -5,3 +5,5 @@
node_modules
/test
tmp
*.db
server

View file

@ -48,12 +48,6 @@ test-e2e:
@(${currentDir}/scripts/e2e/test.sh)
.PHONY: test-e2e
test-selfhost:
@echo "==> running a smoke test for self-hosting"
@${currentDir}/host/smoketest/run_test.sh ${tarballPath}
.PHONY: test-selfhost
# development
dev-server:
@echo "==> running dev environment"

View file

@ -4,48 +4,33 @@ This guide documents the steps for installing the Dnote server on your own machi
## Overview
Dnote server comes as a single binary file that you can simply download and run. It uses Postgres as the database.
Dnote server comes as a single binary file that you can simply download and run. It uses SQLite as the database.
## Installation
1. Install Postgres 11+.
2. Create a `dnote` database by running `createdb dnote`
3. Download the official Dnote server release from the [release page](https://github.com/dnote/dnote/releases).
4. Extract the archive and move the `dnote-server` executable to `/usr/local/bin`.
1. Download the official Dnote server release from the [release page](https://github.com/dnote/dnote/releases).
2. Extract the archive and move the `dnote-server` executable to `/usr/local/bin`.
```bash
tar -xzf dnote-server-$version-$os.tar.gz
mv ./dnote-server /usr/local/bin
```
4. Run Dnote
3. Run Dnote
```bash
GO_ENV=PRODUCTION \
OnPremises=true \
DBHost=localhost \
DBPort=5432 \
DBName=dnote \
DBUser=$user \
DBPassword=$password \
WebURL=$webURL \
SmtpHost=$SmtpHost \
SmtpPort=$SmtpPort \
SmtpUsername=$SmtpUsername \
SmtpPassword=$SmtpPassword \
DisableRegistration=false \
dnote-server start
dnote-server start --webUrl=$webURL
```
Replace `$user`, `$password` with the credentials of the Postgres user that owns the `dnote` database.
Replace `$webURL` with the full URL to your server, without a trailing slash (e.g. `https://your.server`).
Replace `$SmtpHost`, `SmtpPort`, `$SmtpUsername`, `$SmtpPassword` with actual values, if you would like to receive spaced repetition through email.
Additional flags:
- `--port`: Server port (default: `3000`)
- `--disableRegistration`: Disable user registration (default: `false`)
- `--logLevel`: Log level: `debug`, `info`, `warn`, or `error` (default: `info`)
- `--appEnv`: environment (default: `PRODUCTION`)
Replace `DisableRegistration` to `true` if you would like to disable user registrations.
By default, dnote server will run on the port 3000.
You can also use environment variables: `PORT`, `WebURL`, `DisableRegistration`, `LOG_LEVEL`, `APP_ENV`.
## Configuration
@ -127,33 +112,31 @@ User=$user
Restart=always
RestartSec=3
WorkingDirectory=/home/$user
ExecStart=/usr/local/bin/dnote-server start
Environment=GO_ENV=PRODUCTION
Environment=OnPremises=true
Environment=DBHost=localhost
Environment=DBPort=5432
Environment=DBName=dnote
Environment=DBUser=$DBUser
Environment=DBPassword=$DBPassword
Environment=DBSkipSSL=true
Environment=WebURL=$WebURL
Environment=SmtpHost=
Environment=SmtpPort=
Environment=SmtpUsername=
Environment=SmtpPassword=
ExecStart=/usr/local/bin/dnote-server start --webUrl=$WebURL
[Install]
WantedBy=multi-user.target
```
Replace `$user`, `$WebURL`, `$DBUser`, and `$DBPassword` with the actual values.
Replace `$user` and `$WebURL` with the actual values.
Optionally, if you would like to send spaced repetitions throught email, populate `SmtpHost`, `SmtpPort`, `SmtpUsername`, and `SmtpPassword`.
By default, the database will be stored at `$XDG_DATA_HOME/dnote/server.db` (typically `~/.local/share/dnote/server.db`). To use a custom location, add `--dbPath=/path/to/database.db` to the `ExecStart` command.
2. Reload the change by running `sudo systemctl daemon-reload`.
3. Enable the Daemon by running `sudo systemctl enable dnote`.`
4. Start the Daemon by running `sudo systemctl start dnote`
### Optional: Email Support
To enable sending emails, add the following environment variables to your configuration. But they are not required.
- `SmtpHost` - SMTP server hostname
- `SmtpPort` - SMTP server port
- `SmtpUsername` - SMTP username
- `SmtpPassword` - SMTP password
For systemd, add these as additional `Environment=` lines in `/etc/systemd/system/dnote.service`.
### Configure clients
Let's configure Dnote clients to connect to the self-hosted web API endpoint.
@ -166,7 +149,7 @@ The following is an example configuration:
```yaml
editor: nvim
apiEndpoint: https://api.getdnote.com
apiEndpoint: https://localhost:3000/api
```
Simply change the value for `apiEndpoint` to a full URL to the self-hosted instance, followed by '/api', and save the configuration file.
@ -177,7 +160,3 @@ e.g.
editor: nvim
apiEndpoint: my-dnote-server.com/api
```
#### Browser extension
Navigate into the 'Settings' tab and set the values for 'API URL', and 'Web URL'.

19
go.mod
View file

@ -3,7 +3,6 @@ module github.com/dnote/dnote
go 1.25
require (
github.com/aymerick/douceur v0.2.0
github.com/dnote/actions v0.2.0
github.com/fatih/color v1.18.0
github.com/google/go-cmp v0.7.0
@ -12,42 +11,34 @@ require (
github.com/gorilla/csrf v1.7.3
github.com/gorilla/mux v1.8.1
github.com/gorilla/schema v1.4.1
github.com/joho/godotenv v1.5.1
github.com/lib/pq v1.10.9
github.com/mattn/go-sqlite3 v1.14.32
github.com/pkg/errors v0.9.1
github.com/radovskyb/watcher v1.0.7
github.com/robfig/cron v1.2.0
github.com/rubenv/sql-migrate v1.8.0
github.com/sergi/go-diff v1.3.1
github.com/spf13/cobra v1.10.1
golang.org/x/crypto v0.42.0
golang.org/x/time v0.13.0
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df
gopkg.in/yaml.v2 v2.4.0
gorm.io/driver/postgres v1.5.7
gorm.io/gorm v1.25.7
gorm.io/driver/sqlite v1.6.0
gorm.io/gorm v1.30.0
)
require (
github.com/PuerkitoBio/goquery v1.10.3 // indirect
github.com/andybalholm/cascadia v1.3.3 // indirect
github.com/go-gorp/gorp/v3 v3.1.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/gorilla/css v1.0.1 // indirect
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.4.3 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/spf13/pflag v1.0.10 // indirect
golang.org/x/net v0.44.0 // indirect
github.com/stretchr/testify v1.8.1 // indirect
golang.org/x/sys v0.36.0 // indirect
golang.org/x/term v0.35.0 // indirect
golang.org/x/text v0.29.0 // indirect
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
)

110
go.sum
View file

@ -1,10 +1,5 @@
github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo=
github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y=
github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kktS1LM=
github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA=
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -12,12 +7,7 @@ github.com/dnote/actions v0.2.0 h1:P1ut2/QRKwfAzIIB374vN9A4IanU94C/payEocvngYo=
github.com/dnote/actions v0.2.0/go.mod h1:bBIassLhppVQdbC3iaE92SHBpM1HOVe+xZoAlj9ROxw=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
github.com/go-gorp/gorp/v3 v3.1.0 h1:ItKF/Vbuj31dmV4jxA1qblpSwkl9g1typ24xoe70IGs=
github.com/go-gorp/gorp/v3 v3.1.0/go.mod h1:dLEjIyyRNiXvNZ8PSmzpt1GsWAUK8kjVhEpjH8TixEw=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY=
@ -30,8 +20,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0=
github.com/gorilla/csrf v1.7.3/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk=
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/gorilla/schema v1.4.1 h1:jUg5hUjCSDZpNGLuXQOgIWGdlgrIdYvgQ0wZtdK1M3E=
@ -40,47 +28,35 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/poy/onpar v1.1.2 h1:QaNrNiZx0+Nar5dLgTVp5mXkyoVFIbepjyEoGSnhbAY=
github.com/poy/onpar v1.1.2/go.mod h1:6X8FLNoxyr9kkmnlqpK6LSoiOtrO6MICtWwEuWkLjzg=
github.com/radovskyb/watcher v1.0.7 h1:AYePLih6dpmS32vlHfhCeli8127LzkIgwJGcwwe8tUE=
github.com/radovskyb/watcher v1.0.7/go.mod h1:78okwvY5wPdzcb1UYnip1pvrZNIVEIh/Cm+ZuvsUYIg=
github.com/robfig/cron v1.2.0 h1:ZjScXvvxeQ63Dbyxy76Fj3AT3Ut0aKsyd2/tl3DTMuQ=
github.com/robfig/cron v1.2.0/go.mod h1:JGuDeoQd7Z6yL4zQhZ3OPEVHB7fL6Ka6skscFHfmt2k=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rubenv/sql-migrate v1.8.0 h1:dXnYiJk9k3wetp7GfQbKJcPHjVJL6YK19tKj8t2Ns0o=
github.com/rubenv/sql-migrate v1.8.0/go.mod h1:F2bGFBwCU+pnmbtNYDeKvSuvL6lBVtXDXUUv5t+u1qw=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8=
github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I=
@ -90,88 +66,24 @@ github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I=
golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ=
golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI=
golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk=
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk=
@ -187,7 +99,7 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM=
gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA=
gorm.io/gorm v1.25.7 h1:VsD6acwRjz2zFxGO50gPO6AkNs7KKnvfzUjHQhZDz/A=
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=
gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=

View file

@ -1,35 +1,14 @@
version: "3"
services:
postgres:
image: postgres:14-alpine
environment:
POSTGRES_USER: dnote
POSTGRES_PASSWORD: dnote
POSTGRES_DB: dnote
volumes:
- ./dnote_data:/var/lib/postgresql/data
restart: always
dnote:
image: dnote/dnote:latest
environment:
GO_ENV: PRODUCTION
DBSkipSSL: "true"
DBHost: postgres
DBPort: 5432
DBName: dnote
DBUser: dnote
DBPassword: dnote
APP_ENV: PRODUCTION
WebURL: localhost:3000
OnPremises: "true"
SmtpHost:
SmtpPort:
SmtpUsername:
SmtpPassword:
DisableRegistration: "false"
ports:
- 3000:3000
depends_on:
- postgres
volumes:
- ./dnote_data:/data
restart: always

View file

@ -1,25 +1,6 @@
#!/bin/sh
wait_for_db() {
HOST=${DBHost:-postgres}
PORT=${DBPort:-5432}
echo "Waiting for the database connection..."
attempts=0
max_attempts=10
while [ $attempts -lt $max_attempts ]; do
nc -z "${HOST}" "${PORT}" 2>/dev/null && break
echo "Waiting for db at ${HOST}:${PORT}..."
sleep 5
attempts=$((attempts+1))
done
if [ $attempts -eq $max_attempts ]; then
echo "Timed out while waiting for db at ${HOST}:${PORT}"
exit 1
fi
}
wait_for_db
# Set default DBPath to /data if not specified
export DBPath=${DBPath:-/data/dnote.db}
exec "$@"

View file

@ -1 +0,0 @@
/volume

View file

@ -1,9 +0,0 @@
This directory contains a smoke test for running a self-hosted instance using a virtual machine.
## Instruction
The following script will set up a test environment in Vagrant and run the test.
```
./run_test.sh
```

View file

@ -1,9 +0,0 @@
# -*- mode: ruby -*-
Vagrant.configure("2") do |config|
config.vm.box = "ubuntu/jammy64"
config.vm.synced_folder './volume', '/vagrant'
config.vm.network "forwarded_port", guest: 2300, host: 2300
config.vm.provision 'shell', path: './setup.sh', privileged: false
end

View file

@ -1,42 +0,0 @@
#!/usr/bin/env bash
# run_test.sh builds a fresh server image, and mounts it on a fresh
# virtual machine and runs a smoke test. If a tarball path is not provided,
# this script builds a new version and uses it.
set -ex
# tarballPath is an absolute path to a release tarball containing the dnote server.
tarballPath=$1
dir=$(dirname "${BASH_SOURCE[0]}")
projectDir="$dir/../.."
# build
if [ -z "$tarballPath" ]; then
pushd "$projectDir"
make version=integration_test build-server
popd
tarballPath="$projectDir/build/server/dnote_server_integration_test_linux_amd64.tar.gz"
fi
pushd "$dir"
# start a virtual machine
volume="$dir/volume"
rm -rf "$volume"
mkdir -p "$volume"
cp "$tarballPath" "$volume"
cp "$dir/testsuite.sh" "$volume"
vagrant up
# run tests
set +e
if ! vagrant ssh -c "/vagrant/testsuite.sh"; then
echo "Test failed. Please see the output."
vagrant halt
exit 1
fi
set -e
vagrant halt
popd

View file

@ -1,21 +0,0 @@
#!/usr/bin/env bash
set -ex
sudo apt-get install wget ca-certificates
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add -
sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt/ `lsb_release -cs`-pgdg main" >> /etc/apt/sources.list.d/pgdg.list'
sudo apt-get update
sudo apt-get install -y postgresql-14
# set up database
sudo usermod -a -G sudo postgres
cd /var/lib/postgresql
sudo -u postgres createdb dnote
sudo -u postgres psql -c "ALTER USER postgres PASSWORD 'postgres';"
# allow connection from host and allow to connect without password
sudo sed -i "/port*/a listen_addresses = '*'" /etc/postgresql/14/main/postgresql.conf
sudo sed -i 's/host.*all.*.all.*md5/# &/' /etc/postgresql/14/main/pg_hba.conf
sudo sed -i "$ a host all all all trust" /etc/postgresql/14/main/pg_hba.conf
sudo service postgresql restart

View file

@ -1,45 +0,0 @@
#!/usr/bin/env bash
# testsuite.sh runs the smoke tests for a self-hosted instance.
# It is meant to be run inside a virtual machine which has been
# set up by an entry script.
set -eux
echo 'Running a smoke test'
cd /var/lib/postgresql
sudo -u postgres dropdb dnote
sudo -u postgres createdb dnote
cd /vagrant
tar -xvf dnote_server_integration_test_linux_amd64.tar.gz
GO_ENV=PRODUCTION \
DBHost=localhost \
DBPort=5432 \
DBName=dnote \
DBUser=postgres \
DBPassword=postgres \
WebURL=localhost:3000 \
./dnote-server -port 2300 start & sleep 3
assert_http_status() {
url=$1
expected=$2
echo "======== [TEST CASE] asserting response status code for $url ========"
got=$(curl --write-out %"{http_code}" --silent --output /dev/null "$url")
if [ "$got" != "$expected" ]; then
echo "======== ASSERTION FAILED ========"
echo "status code for $url: expected: $expected got: $got"
echo "=================================="
exit 1
fi
}
assert_http_status http://localhost:2300 "302"
assert_http_status http://localhost:2300/health "200"
echo "======== [SUCCESS] TEST PASSED! ========"

View file

@ -22,7 +22,7 @@ package assert
import (
"encoding/json"
"fmt"
"io/ioutil"
"io"
"net/http"
"reflect"
"runtime/debug"
@ -138,7 +138,7 @@ func EqualJSON(t *testing.T, a, b, message string) {
// expected
func StatusCodeEquals(t *testing.T, res *http.Response, expected int, message string) {
if res.StatusCode != expected {
body, err := ioutil.ReadAll(res.Body)
body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(errors.Wrap(err, "reading body"))
}

View file

@ -23,7 +23,7 @@ package client
import (
"encoding/json"
"fmt"
"io/ioutil"
"io"
"net/http"
"net/url"
"strconv"
@ -95,7 +95,7 @@ func checkRespErr(res *http.Response) error {
return nil
}
body, err := ioutil.ReadAll(res.Body)
body, err := io.ReadAll(res.Body)
if err != nil {
return errors.Wrapf(err, "server responded with %d but client could not read the response body", res.StatusCode)
}
@ -169,7 +169,7 @@ func GetSyncState(ctx context.DnoteCtx) (GetSyncStateResp, error) {
return ret, errors.Wrap(err, "constructing http request")
}
body, err := ioutil.ReadAll(res.Body)
body, err := io.ReadAll(res.Body)
if err != nil {
return ret, errors.Wrap(err, "reading the response body")
}
@ -233,7 +233,7 @@ func GetSyncFragment(ctx context.DnoteCtx, afterUSN int) (GetSyncFragmentResp, e
path := fmt.Sprintf("/v3/sync/fragment?%s", queryStr)
res, err := doAuthorizedReq(ctx, "GET", path, "", nil)
body, err := ioutil.ReadAll(res.Body)
body, err := io.ReadAll(res.Body)
if err != nil {
return GetSyncFragmentResp{}, errors.Wrap(err, "reading the response body")
}

View file

@ -20,7 +20,7 @@ package edit
import (
"database/sql"
"io/ioutil"
"os"
"strconv"
"github.com/dnote/dnote/pkg/cli/context"
@ -45,7 +45,7 @@ func waitEditorNoteContent(ctx context.DnoteCtx, note database.Note) (string, er
return "", errors.Wrap(err, "getting temporarily content file path")
}
if err := ioutil.WriteFile(fpath, []byte(note.Body), 0644); err != nil {
if err := os.WriteFile(fpath, []byte(note.Body), 0644); err != nil {
return "", errors.Wrap(err, "preparing tmp content file")
}

View file

@ -20,7 +20,7 @@ package config
import (
"fmt"
"io/ioutil"
"os"
"github.com/dnote/dnote/pkg/cli/consts"
"github.com/dnote/dnote/pkg/cli/context"
@ -66,7 +66,7 @@ func Read(ctx context.DnoteCtx) (Config, error) {
var ret Config
configPath := GetPath(ctx)
b, err := ioutil.ReadFile(configPath)
b, err := os.ReadFile(configPath)
if err != nil {
return ret, errors.Wrap(err, "reading config file")
}
@ -88,7 +88,7 @@ func Write(ctx context.DnoteCtx, cf Config) error {
return errors.Wrap(err, "marshalling config into YAML")
}
err = ioutil.WriteFile(path, b, 0644)
err = os.WriteFile(path, b, 0644)
if err != nil {
return errors.Wrap(err, "writing the config file")
}

View file

@ -32,11 +32,11 @@ import (
"github.com/dnote/dnote/pkg/cli/consts"
"github.com/dnote/dnote/pkg/cli/context"
"github.com/dnote/dnote/pkg/cli/database"
"github.com/dnote/dnote/pkg/cli/dirs"
"github.com/dnote/dnote/pkg/cli/log"
"github.com/dnote/dnote/pkg/cli/migrate"
"github.com/dnote/dnote/pkg/cli/utils"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/dirs"
"github.com/pkg/errors"
"github.com/spf13/cobra"
)

View file

@ -23,7 +23,6 @@ package migrate
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"time"
@ -232,7 +231,7 @@ func readSchema(ctx context.DnoteCtx) (schema, error) {
path := getSchemaPath(ctx)
b, err := ioutil.ReadFile(path)
b, err := os.ReadFile(path)
if err != nil {
return ret, errors.Wrap(err, "Failed to read schema file")
}
@ -252,7 +251,7 @@ func writeSchema(ctx context.DnoteCtx, s schema) error {
return errors.Wrap(err, "Failed to marshal schema into yaml")
}
if err := ioutil.WriteFile(path, d, 0644); err != nil {
if err := os.WriteFile(path, d, 0644); err != nil {
return errors.Wrap(err, "Failed to write schema file")
}
@ -504,7 +503,7 @@ func migrateToV1(ctx context.DnoteCtx) error {
func migrateToV2(ctx context.DnoteCtx) error {
notePath := fmt.Sprintf("%s/dnote", ctx.Paths.LegacyDnote)
b, err := ioutil.ReadFile(notePath)
b, err := os.ReadFile(notePath)
if err != nil {
return errors.Wrap(err, "Failed to read the note file")
}
@ -548,7 +547,7 @@ func migrateToV2(ctx context.DnoteCtx) error {
return errors.Wrap(err, "Failed to marshal new dnote into JSON")
}
err = ioutil.WriteFile(notePath, d, 0644)
err = os.WriteFile(notePath, d, 0644)
if err != nil {
return errors.Wrap(err, "Failed to write the new dnote into the file")
}
@ -561,7 +560,7 @@ func migrateToV3(ctx context.DnoteCtx) error {
notePath := fmt.Sprintf("%s/dnote", ctx.Paths.LegacyDnote)
actionsPath := fmt.Sprintf("%s/actions", ctx.Paths.LegacyDnote)
b, err := ioutil.ReadFile(notePath)
b, err := os.ReadFile(notePath)
if err != nil {
return errors.Wrap(err, "Failed to read the note file")
}
@ -615,7 +614,7 @@ func migrateToV3(ctx context.DnoteCtx) error {
return errors.Wrap(err, "Failed to marshal actions into JSON")
}
err = ioutil.WriteFile(actionsPath, a, 0644)
err = os.WriteFile(actionsPath, a, 0644)
if err != nil {
return errors.Wrap(err, "Failed to write the actions into a file")
}
@ -647,7 +646,7 @@ func getEditorCommand() string {
func migrateToV4(ctx context.DnoteCtx) error {
configPath := fmt.Sprintf("%s/dnoterc", ctx.Paths.LegacyDnote)
b, err := ioutil.ReadFile(configPath)
b, err := os.ReadFile(configPath)
if err != nil {
return errors.Wrap(err, "Failed to read the config file")
}
@ -668,7 +667,7 @@ func migrateToV4(ctx context.DnoteCtx) error {
return errors.Wrap(err, "Failed to marshal config into JSON")
}
err = ioutil.WriteFile(configPath, data, 0644)
err = os.WriteFile(configPath, data, 0644)
if err != nil {
return errors.Wrap(err, "Failed to write the config into a file")
}
@ -680,7 +679,7 @@ func migrateToV4(ctx context.DnoteCtx) error {
func migrateToV5(ctx context.DnoteCtx) error {
actionsPath := fmt.Sprintf("%s/actions", ctx.Paths.LegacyDnote)
b, err := ioutil.ReadFile(actionsPath)
b, err := os.ReadFile(actionsPath)
if err != nil {
return errors.Wrap(err, "reading the actions file")
}
@ -738,7 +737,7 @@ func migrateToV5(ctx context.DnoteCtx) error {
if err != nil {
return errors.Wrap(err, "marshalling result into JSON")
}
err = ioutil.WriteFile(actionsPath, a, 0644)
err = os.WriteFile(actionsPath, a, 0644)
if err != nil {
return errors.Wrap(err, "writing the result into a file")
}
@ -750,7 +749,7 @@ func migrateToV5(ctx context.DnoteCtx) error {
func migrateToV6(ctx context.DnoteCtx) error {
notePath := fmt.Sprintf("%s/dnote", ctx.Paths.LegacyDnote)
b, err := ioutil.ReadFile(notePath)
b, err := os.ReadFile(notePath)
if err != nil {
return errors.Wrap(err, "Failed to read the note file")
}
@ -791,7 +790,7 @@ func migrateToV6(ctx context.DnoteCtx) error {
return errors.Wrap(err, "Failed to marshal new dnote into JSON")
}
err = ioutil.WriteFile(notePath, d, 0644)
err = os.WriteFile(notePath, d, 0644)
if err != nil {
return errors.Wrap(err, "Failed to write the new dnote into the file")
}
@ -805,7 +804,7 @@ func migrateToV6(ctx context.DnoteCtx) error {
func migrateToV7(ctx context.DnoteCtx) error {
actionPath := fmt.Sprintf("%s/actions", ctx.Paths.LegacyDnote)
b, err := ioutil.ReadFile(actionPath)
b, err := os.ReadFile(actionPath)
if err != nil {
return errors.Wrap(err, "reading actions file")
}
@ -857,7 +856,7 @@ func migrateToV7(ctx context.DnoteCtx) error {
return errors.Wrap(err, "marshalling new actions")
}
err = ioutil.WriteFile(actionPath, d, 0644)
err = os.WriteFile(actionPath, d, 0644)
if err != nil {
return errors.Wrap(err, "writing new actions to a file")
}
@ -874,7 +873,7 @@ func migrateToV8(ctx context.DnoteCtx) error {
// 1. Migrate the the dnote file
dnoteFilePath := fmt.Sprintf("%s/dnote", ctx.Paths.LegacyDnote)
b, err := ioutil.ReadFile(dnoteFilePath)
b, err := os.ReadFile(dnoteFilePath)
if err != nil {
return errors.Wrap(err, "reading the notes")
}
@ -914,7 +913,7 @@ func migrateToV8(ctx context.DnoteCtx) error {
// 2. Migrate the actions file
actionsPath := fmt.Sprintf("%s/actions", ctx.Paths.LegacyDnote)
b, err = ioutil.ReadFile(actionsPath)
b, err = os.ReadFile(actionsPath)
if err != nil {
return errors.Wrap(err, "reading the actions")
}
@ -939,7 +938,7 @@ func migrateToV8(ctx context.DnoteCtx) error {
// 3. Migrate the timestamps file
timestampsPath := fmt.Sprintf("%s/timestamps", ctx.Paths.LegacyDnote)
b, err = ioutil.ReadFile(timestampsPath)
b, err = os.ReadFile(timestampsPath)
if err != nil {
return errors.Wrap(err, "reading the timestamps")
}

View file

@ -21,7 +21,6 @@ package migrate
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"testing"
@ -65,7 +64,7 @@ func TestMigrateToV1(t *testing.T) {
if err != nil {
panic(errors.Wrap(err, "Failed to get absolute YAML path").Error())
}
ioutil.WriteFile(yamlPath, []byte{}, 0644)
os.WriteFile(yamlPath, []byte{}, 0644)
// execute
if err := migrateToV1(ctx); err != nil {

View file

@ -21,8 +21,8 @@ package migrate
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"net/http/httptest"
"testing"
"time"
@ -1079,7 +1079,7 @@ func TestLocalMigration12(t *testing.T) {
data := []byte("editor: vim")
path := fmt.Sprintf("%s/%s/dnoterc", ctx.Paths.Config, consts.DnoteDirName)
if err := ioutil.WriteFile(path, data, 0644); err != nil {
if err := os.WriteFile(path, data, 0644); err != nil {
t.Fatal(errors.Wrap(err, "Failed to write schema file"))
}
@ -1090,7 +1090,7 @@ func TestLocalMigration12(t *testing.T) {
}
// test
b, err := ioutil.ReadFile(path)
b, err := os.ReadFile(path)
if err != nil {
t.Fatal(errors.Wrap(err, "reading config"))
}
@ -1117,7 +1117,7 @@ func TestLocalMigration13(t *testing.T) {
data := []byte("editor: vim\napiEndpoint: https://test.com/api")
path := fmt.Sprintf("%s/%s/dnoterc", ctx.Paths.Config, consts.DnoteDirName)
if err := ioutil.WriteFile(path, data, 0644); err != nil {
if err := os.WriteFile(path, data, 0644); err != nil {
t.Fatal(errors.Wrap(err, "Failed to write schema file"))
}
@ -1128,7 +1128,7 @@ func TestLocalMigration13(t *testing.T) {
}
// test
b, err := ioutil.ReadFile(path)
b, err := os.ReadFile(path)
if err != nil {
t.Fatal(errors.Wrap(err, "reading config"))
}

View file

@ -23,7 +23,6 @@ import (
"bytes"
"encoding/json"
"io"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
@ -81,7 +80,7 @@ func WriteFile(ctx context.DnoteCtx, content []byte, filename string) {
panic(err)
}
if err := ioutil.WriteFile(dp, content, 0644); err != nil {
if err := os.WriteFile(dp, content, 0644); err != nil {
panic(err)
}
}
@ -90,7 +89,7 @@ func WriteFile(ctx context.DnoteCtx, content []byte, filename string) {
func ReadFile(ctx context.DnoteCtx, filename string) []byte {
path := filepath.Join(ctx.Paths.LegacyDnote, filename)
b, err := ioutil.ReadFile(path)
b, err := os.ReadFile(path)
if err != nil {
panic(err)
}
@ -101,7 +100,7 @@ func ReadFile(ctx context.DnoteCtx, filename string) []byte {
// ReadJSON reads JSON fixture to the struct at the destination address
func ReadJSON(path string, destination interface{}) {
var dat []byte
dat, err := ioutil.ReadFile(path)
dat, err := os.ReadFile(path)
if err != nil {
panic(errors.Wrap(err, "Failed to load fixture payload"))
}

View file

@ -21,7 +21,6 @@ package ui
import (
"fmt"
"io/ioutil"
"os"
"os/exec"
"strings"
@ -122,7 +121,7 @@ func GetEditorInput(ctx context.DnoteCtx, fpath string) (string, error) {
return "", errors.Wrap(err, "waiting for the editor")
}
b, err := ioutil.ReadFile(fpath)
b, err := os.ReadFile(fpath)
if err != nil {
return "", errors.Wrap(err, "reading the temporary content file")
}

View file

@ -20,7 +20,6 @@ package utils
import (
"io"
"io/ioutil"
"os"
"path/filepath"
@ -35,7 +34,7 @@ func ReadFileAbs(relpath string) []byte {
panic(err)
}
b, err := ioutil.ReadFile(fp)
b, err := os.ReadFile(fp)
if err != nil {
panic(err)
}
@ -80,7 +79,7 @@ func CopyDir(src, dest string) error {
return errors.Wrap(err, "creating destination")
}
entries, err := ioutil.ReadDir(src)
entries, err := os.ReadDir(src)
if err != nil {
return errors.Wrap(err, "reading the directory listing for the input")
}

View file

@ -19,7 +19,6 @@
package dirs
import (
"os"
"testing"
"github.com/dnote/dnote/pkg/assert"
@ -34,7 +33,7 @@ type envTestCase struct {
func testCustomDirs(t *testing.T, testCases []envTestCase) {
for _, tc := range testCases {
os.Setenv(tc.envKey, tc.envVal)
t.Setenv(tc.envKey, tc.envVal)
Reload()

183
pkg/e2e/server_test.go Normal file
View file

@ -0,0 +1,183 @@
/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
*
* This file is part of Dnote.
*
* Dnote is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Dnote is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dnote. If not, see <https://www.gnu.org/licenses/>.
*/
package main
import (
"fmt"
"net/http"
"os"
"os/exec"
"strings"
"testing"
"time"
"github.com/dnote/dnote/pkg/assert"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
var testServerBinary string
func init() {
// Build server binary in temp directory
tmpDir := os.TempDir()
testServerBinary = fmt.Sprintf("%s/dnote-test-server", tmpDir)
buildCmd := exec.Command("go", "build", "-tags", "fts5", "-o", testServerBinary, "../server")
if out, err := buildCmd.CombinedOutput(); err != nil {
panic(fmt.Sprintf("failed to build server: %v\n%s", err, out))
}
}
func TestServerStart(t *testing.T) {
tmpDB := t.TempDir() + "/test.db"
port := "13456" // Use different port to avoid conflicts with main test server
// Start server in background
cmd := exec.Command(testServerBinary, "start", "-port", port)
cmd.Env = append(os.Environ(),
"DBPath="+tmpDB,
"WebURL=http://localhost:"+port,
"APP_ENV=PRODUCTION",
)
if err := cmd.Start(); err != nil {
t.Fatalf("failed to start server: %v", err)
}
// Ensure cleanup
cleanup := func() {
if cmd.Process != nil {
cmd.Process.Kill()
cmd.Wait() // Wait for process to fully exit
}
}
defer cleanup()
// Wait for server to start and migrations to run
time.Sleep(3 * time.Second)
// Verify server responds to health check
resp, err := http.Get(fmt.Sprintf("http://localhost:%s/health", port))
if err != nil {
t.Fatalf("failed to reach server health endpoint: %v", err)
}
defer resp.Body.Close()
assert.Equal(t, resp.StatusCode, 200, "health endpoint should return 200")
// Kill server before checking database to avoid locks
cleanup()
// Verify database file was created
if _, err := os.Stat(tmpDB); os.IsNotExist(err) {
t.Fatalf("database file was not created at %s", tmpDB)
}
// Verify migrations ran by checking database
db, err := gorm.Open(sqlite.Open(tmpDB), &gorm.Config{})
if err != nil {
t.Fatalf("failed to open test database: %v", err)
}
// Verify migrations ran
var count int64
if err := db.Raw("SELECT COUNT(*) FROM schema_migrations").Scan(&count).Error; err != nil {
t.Fatalf("schema_migrations table not found: %v", err)
}
if count == 0 {
t.Fatal("no migrations were run")
}
// Verify FTS table exists and is functional
if err := db.Exec("SELECT * FROM notes_fts LIMIT 1").Error; err != nil {
t.Fatalf("notes_fts table not found or not functional: %v", err)
}
}
func TestServerVersion(t *testing.T) {
cmd := exec.Command("go", "run", "-tags", "fts5", "../server", "version")
output, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("version command failed: %v", err)
}
outputStr := string(output)
if !strings.Contains(outputStr, "dnote-server-") {
t.Errorf("expected version output to contain 'dnote-server-', got: %s", outputStr)
}
}
func TestServerRootCommand(t *testing.T) {
cmd := exec.Command(testServerBinary)
output, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("server command failed: %v", err)
}
outputStr := string(output)
assert.Equal(t, strings.Contains(outputStr, "Dnote server - a simple command line notebook"), true, "output should contain description")
assert.Equal(t, strings.Contains(outputStr, "start: Start the server"), true, "output should contain start command")
assert.Equal(t, strings.Contains(outputStr, "version: Print the version"), true, "output should contain version command")
}
func TestServerStartHelp(t *testing.T) {
cmd := exec.Command(testServerBinary, "start", "--help")
output, _ := cmd.CombinedOutput()
outputStr := string(output)
assert.Equal(t, strings.Contains(outputStr, "dnote-server start [flags]"), true, "output should contain usage")
assert.Equal(t, strings.Contains(outputStr, "-appEnv"), true, "output should contain appEnv flag")
assert.Equal(t, strings.Contains(outputStr, "-port"), true, "output should contain port flag")
assert.Equal(t, strings.Contains(outputStr, "-webUrl"), true, "output should contain webUrl flag")
assert.Equal(t, strings.Contains(outputStr, "-dbPath"), true, "output should contain dbPath flag")
assert.Equal(t, strings.Contains(outputStr, "-disableRegistration"), true, "output should contain disableRegistration flag")
}
func TestServerStartInvalidConfig(t *testing.T) {
cmd := exec.Command(testServerBinary, "start")
// Clear WebURL env var so validation fails
cmd.Env = []string{}
output, err := cmd.CombinedOutput()
// Should exit with non-zero status
if err == nil {
t.Fatal("expected command to fail with invalid config")
}
outputStr := string(output)
assert.Equal(t, strings.Contains(outputStr, "Error:"), true, "output should contain error message")
assert.Equal(t, strings.Contains(outputStr, "Invalid WebURL"), true, "output should mention invalid WebURL")
assert.Equal(t, strings.Contains(outputStr, "dnote-server start [flags]"), true, "output should show usage")
assert.Equal(t, strings.Contains(outputStr, "-webUrl"), true, "output should show flags")
}
func TestServerUnknownCommand(t *testing.T) {
cmd := exec.Command(testServerBinary, "unknown")
output, err := cmd.CombinedOutput()
// Should exit with non-zero status
if err == nil {
t.Fatal("expected command to fail with unknown command")
}
outputStr := string(output)
assert.Equal(t, strings.Contains(outputStr, "Unknown command"), true, "output should contain unknown command message")
assert.Equal(t, strings.Contains(outputStr, "Dnote server - a simple command line notebook"), true, "output should show help")
}

View file

@ -22,7 +22,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"io"
"log"
"net/http"
"net/http/httptest"
@ -39,16 +39,17 @@ import (
clitest "github.com/dnote/dnote/pkg/cli/testutils"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/app"
"github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/controllers"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/mailer"
apitest "github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
"gorm.io/gorm"
)
var cliBinaryName string
var server *httptest.Server
var serverDb *gorm.DB
var serverTime = time.Date(2017, time.March, 14, 21, 15, 0, 0, time.UTC)
var tmpDirPath string
@ -82,22 +83,22 @@ func clearTmp(t *testing.T) {
}
func TestMain(m *testing.M) {
// Set up server database
apitest.InitTestDB()
// Set up server database - use file-based DB for e2e tests
dbPath := fmt.Sprintf("%s/server.db", testDir)
serverDb = apitest.InitDB(dbPath)
mockClock := clock.NewMock()
mockClock.SetNow(serverTime)
a := app.NewTest()
a.Clock = mockClock
a.EmailTemplates = mailer.Templates{}
a.EmailBackend = &apitest.MockEmailbackendImplementation{}
a.DB = serverDb
a.WebURL = os.Getenv("WebURL")
var err error
server, err = controllers.NewServer(&app.App{
Clock: mockClock,
EmailTemplates: mailer.Templates{},
EmailBackend: &apitest.MockEmailbackendImplementation{},
DB: apitest.DB,
Config: config.Config{
WebURL: os.Getenv("WebURL"),
},
})
server, err = controllers.NewServer(&a)
if err != nil {
panic(errors.Wrap(err, "initializing router"))
}
@ -124,11 +125,11 @@ func TestMain(m *testing.M) {
// helpers
func setupUser(t *testing.T, ctx *context.DnoteCtx) database.User {
user := apitest.SetupUserData()
apitest.SetupAccountData(user, "alice@example.com", "pass1234")
user := apitest.SetupUserData(serverDb)
apitest.SetupAccountData(serverDb, user, "alice@example.com", "pass1234")
// log in the user in CLI
session := apitest.SetupSession(t, user)
session := apitest.SetupSession(serverDb, user)
cliDatabase.MustExec(t, "inserting session_key", ctx.DB, "INSERT INTO system (key, value) VALUES (?, ?)", consts.SystemSessionKey, session.Key)
cliDatabase.MustExec(t, "inserting session_key_expiry", ctx.DB, "INSERT INTO system (key, value) VALUES (?, ?)", consts.SystemSessionKeyExpiry, session.ExpiresAt.Unix())
@ -184,9 +185,9 @@ func doHTTPReq(t *testing.T, method, path, payload, message string, user databas
panic(errors.Wrap(err, "constructing http request"))
}
res := apitest.HTTPAuthDo(t, req, user)
res := apitest.HTTPAuthDo(t, serverDb, req, user)
if res.StatusCode >= 400 {
bs, err := ioutil.ReadAll(res.Body)
bs, err := io.ReadAll(res.Body)
if err != nil {
panic(errors.Wrap(err, "parsing response body for error"))
}
@ -202,8 +203,8 @@ type assertFunc func(t *testing.T, ctx context.DnoteCtx, user database.User, ids
func testSyncCmd(t *testing.T, fullSync bool, setup setupFunc, assert assertFunc) {
// clean up
apitest.ClearData(apitest.DB)
defer apitest.ClearData(apitest.DB)
apitest.ClearData(serverDb)
defer apitest.ClearData(serverDb)
clearTmp(t)
@ -234,7 +235,6 @@ type systemState struct {
// checkState compares the state of the client and the server with the given system state
func checkState(t *testing.T, ctx context.DnoteCtx, user database.User, expected systemState) {
serverDB := apitest.DB
clientDB := ctx.DB
var clientBookCount, clientNoteCount int
@ -251,12 +251,12 @@ func checkState(t *testing.T, ctx context.DnoteCtx, user database.User, expected
assert.Equal(t, clientLastSyncAt, expected.clientLastSyncAt, "client last_sync_at mismatch")
var serverBookCount, serverNoteCount int64
apitest.MustExec(t, serverDB.Model(&database.Note{}).Count(&serverNoteCount), "counting server notes")
apitest.MustExec(t, serverDB.Model(&database.Book{}).Count(&serverBookCount), "counting api notes")
apitest.MustExec(t, serverDb.Model(&database.Note{}).Count(&serverNoteCount), "counting server notes")
apitest.MustExec(t, serverDb.Model(&database.Book{}).Count(&serverBookCount), "counting api notes")
assert.Equal(t, serverNoteCount, expected.serverNoteCount, "server note count mismatch")
assert.Equal(t, serverBookCount, expected.serverBookCount, "server book count mismatch")
var serverUser database.User
apitest.MustExec(t, serverDB.Where("id = ?", user.ID).First(&serverUser), "finding user")
apitest.MustExec(t, serverDb.Where("id = ?", user.ID).First(&serverUser), "finding user")
assert.Equal(t, serverUser.MaxUSN, expected.serverUserMaxUSN, "user max_usn mismatch")
}
@ -286,8 +286,7 @@ func TestSync_Empty(t *testing.T) {
func TestSync_oneway(t *testing.T) {
t.Run("cli to api only", func(t *testing.T) {
setup := func(t *testing.T, ctx context.DnoteCtx, user database.User) {
apiDB := apitest.DB
apitest.MustExec(t, apiDB.Model(&user).Update("max_usn", 0), "updating user max_usn")
apitest.MustExec(t, serverDb.Model(&user).Update("max_usn", 0), "updating user max_usn")
clitest.RunDnoteCmd(t, dnoteCmdOpts, cliBinaryName, "add", "js", "-c", "js1")
clitest.RunDnoteCmd(t, dnoteCmdOpts, cliBinaryName, "add", "css", "-c", "css1")
@ -295,7 +294,6 @@ func TestSync_oneway(t *testing.T) {
}
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User) {
apiDB := apitest.DB
cliDB := ctx.DB
// test client
@ -339,11 +337,11 @@ func TestSync_oneway(t *testing.T) {
// test server
var apiBookJS, apiBookCSS database.Book
var apiNote1JS, apiNote2JS, apiNote1CSS database.Note
apitest.MustExec(t, apiDB.Model(&database.Note{}).Where("uuid = ?", cliNote1JS.UUID).First(&apiNote1JS), "getting js1 note")
apitest.MustExec(t, apiDB.Model(&database.Note{}).Where("uuid = ?", cliNote2JS.UUID).First(&apiNote2JS), "getting js2 note")
apitest.MustExec(t, apiDB.Model(&database.Note{}).Where("uuid = ?", cliNote1CSS.UUID).First(&apiNote1CSS), "getting css1 note")
apitest.MustExec(t, apiDB.Model(&database.Book{}).Where("uuid = ?", cliBookJS.UUID).First(&apiBookJS), "getting js book")
apitest.MustExec(t, apiDB.Model(&database.Book{}).Where("uuid = ?", cliBookCSS.UUID).First(&apiBookCSS), "getting css book")
apitest.MustExec(t, serverDb.Model(&database.Note{}).Where("uuid = ?", cliNote1JS.UUID).First(&apiNote1JS), "getting js1 note")
apitest.MustExec(t, serverDb.Model(&database.Note{}).Where("uuid = ?", cliNote2JS.UUID).First(&apiNote2JS), "getting js2 note")
apitest.MustExec(t, serverDb.Model(&database.Note{}).Where("uuid = ?", cliNote1CSS.UUID).First(&apiNote1CSS), "getting css1 note")
apitest.MustExec(t, serverDb.Model(&database.Book{}).Where("uuid = ?", cliBookJS.UUID).First(&apiBookJS), "getting js book")
apitest.MustExec(t, serverDb.Model(&database.Book{}).Where("uuid = ?", cliBookCSS.UUID).First(&apiBookCSS), "getting css book")
// assert usn
assert.NotEqual(t, apiNote1JS.USN, 0, "apiNote1JS usn mismatch")
@ -371,7 +369,7 @@ func TestSync_oneway(t *testing.T) {
t.Run("stepSync", func(t *testing.T) {
clearTmp(t)
defer apitest.ClearData(apitest.DB)
defer apitest.ClearData(serverDb)
ctx := context.InitTestCtx(t, paths, nil)
defer context.TeardownTestCtx(t, ctx)
@ -385,7 +383,7 @@ func TestSync_oneway(t *testing.T) {
t.Run("fullSync", func(t *testing.T) {
clearTmp(t)
defer apitest.ClearData(apitest.DB)
defer apitest.ClearData(serverDb)
ctx := context.InitTestCtx(t, paths, nil)
defer context.TeardownTestCtx(t, ctx)
@ -400,7 +398,7 @@ func TestSync_oneway(t *testing.T) {
t.Run("cli to api with edit and delete", func(t *testing.T) {
setup := func(t *testing.T, ctx context.DnoteCtx, user database.User) {
apiDB := apitest.DB
apiDB := serverDb
apitest.MustExec(t, apiDB.Model(&user).Update("max_usn", 0), "updating user max_usn")
clitest.RunDnoteCmd(t, dnoteCmdOpts, cliBinaryName, "add", "js", "-c", "js1")
@ -423,7 +421,7 @@ func TestSync_oneway(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 6,
@ -527,7 +525,7 @@ func TestSync_oneway(t *testing.T) {
t.Run("stepSync", func(t *testing.T) {
clearTmp(t)
defer apitest.ClearData(apitest.DB)
defer apitest.ClearData(serverDb)
ctx := context.InitTestCtx(t, paths, nil)
defer context.TeardownTestCtx(t, ctx)
@ -541,7 +539,7 @@ func TestSync_oneway(t *testing.T) {
t.Run("fullSync", func(t *testing.T) {
clearTmp(t)
defer apitest.ClearData(apitest.DB)
defer apitest.ClearData(serverDb)
ctx := context.InitTestCtx(t, paths, nil)
defer context.TeardownTestCtx(t, ctx)
@ -556,7 +554,7 @@ func TestSync_oneway(t *testing.T) {
t.Run("api to cli", func(t *testing.T) {
setup := func(t *testing.T, ctx context.DnoteCtx, user database.User) map[string]string {
apiDB := apitest.DB
apiDB := serverDb
apitest.MustExec(t, apiDB.Model(&user).Update("max_usn", 0), "updating user max_usn")
@ -603,7 +601,7 @@ func TestSync_oneway(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 6,
@ -804,7 +802,7 @@ func TestSync_twoway(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 9,
@ -1033,7 +1031,7 @@ func TestSync_twoway(t *testing.T) {
}
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
apiDB := apitest.DB
apiDB := serverDb
cliDB := ctx.DB
checkState(t, ctx, user, systemState{
@ -1188,7 +1186,7 @@ func TestSync_twoway(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 4,
@ -1287,7 +1285,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@ -1349,7 +1347,7 @@ func TestSync(t *testing.T) {
}
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@ -1404,7 +1402,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@ -1471,7 +1469,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@ -1537,7 +1535,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@ -1601,7 +1599,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@ -1655,7 +1653,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@ -1708,7 +1706,7 @@ func TestSync(t *testing.T) {
}
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@ -1750,7 +1748,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@ -1816,7 +1814,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@ -1884,7 +1882,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@ -1957,7 +1955,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@ -2021,7 +2019,7 @@ func TestSync(t *testing.T) {
}
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@ -2081,7 +2079,7 @@ func TestSync(t *testing.T) {
}
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
apiDB := apitest.DB
apiDB := serverDb
cliDB := ctx.DB
checkState(t, ctx, user, systemState{
@ -2150,7 +2148,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 2,
@ -2218,7 +2216,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 2,
@ -2298,7 +2296,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 4,
@ -2404,7 +2402,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 2,
@ -2472,7 +2470,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 2,
@ -2555,7 +2553,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
resolvedBody := "<<<<<<< Local\njs1-edited-from-client\n=======\njs1-edited-from-server\n>>>>>>> Server\n"
@ -2630,7 +2628,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@ -2705,7 +2703,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@ -2789,7 +2787,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@ -2862,7 +2860,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@ -2934,7 +2932,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@ -3005,7 +3003,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@ -3072,7 +3070,7 @@ func TestSync(t *testing.T) {
}
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 0,
@ -3127,7 +3125,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@ -3203,7 +3201,7 @@ func TestSync(t *testing.T) {
// In this case, server's change wins and overwrites that of client's
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@ -3278,7 +3276,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@ -3365,7 +3363,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 1,
@ -3454,7 +3452,7 @@ func TestSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
expectedNote1JSBody := `<<<<<<< Local
Moved to the book linux
@ -3566,7 +3564,7 @@ js1`
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 2,
@ -3668,7 +3666,7 @@ js1`
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 2,
@ -3768,7 +3766,7 @@ func TestFullSync(t *testing.T) {
assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) {
cliDB := ctx.DB
apiDB := apitest.DB
apiDB := serverDb
checkState(t, ctx, user, systemState{
clientNoteCount: 2,
@ -3832,8 +3830,8 @@ func TestFullSync(t *testing.T) {
t.Run("stepSync then fullSync", func(t *testing.T) {
// clean up
os.RemoveAll(tmpDirPath)
apitest.ClearData(apitest.DB)
defer apitest.ClearData(apitest.DB)
apitest.ClearData(serverDb)
defer apitest.ClearData(serverDb)
ctx := context.InitTestCtx(t, paths, nil)
defer context.TeardownTestCtx(t, ctx)
@ -3849,8 +3847,8 @@ func TestFullSync(t *testing.T) {
t.Run("fullSync then stepSync", func(t *testing.T) {
// clean up
os.RemoveAll(tmpDirPath)
apitest.ClearData(apitest.DB)
defer apitest.ClearData(apitest.DB)
apitest.ClearData(serverDb)
defer apitest.ClearData(serverDb)
ctx := context.InitTestCtx(t, paths, nil)
defer context.TeardownTestCtx(t, ctx)

View file

@ -1,11 +1,4 @@
GO_ENV=DEVELOPMENT
DBHost=localhost
DBPort=5432
DBName=dnote
DBUser=postgres
DBPassword=postgres
DBSkipSSL=true
APP_ENV=DEVELOPMENT
SmtpUsername=mock-SmtpUsername
SmtpPassword=mock-SmtpPassword
@ -14,4 +7,3 @@ SmtpPort=465
WebURL=http://localhost:3000
DisableRegistration=false
OnPremise=true

View file

@ -1,11 +1,4 @@
GO_ENV=TEST
DBHost=localhost
DBPort=5432
DBName=dnote_test
DBUser=postgres
DBPassword=postgres
DBSkipSSL=true
APP_ENV=TEST
SmtpUsername=mock-SmtpUsername
SmtpPassword=mock-SmtpPassword

View file

@ -20,7 +20,6 @@ package app
import (
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/mailer"
"gorm.io/gorm"
"github.com/pkg/errors"
@ -43,18 +42,23 @@ var (
// App is an application context
type App struct {
DB *gorm.DB
Clock clock.Clock
EmailTemplates mailer.Templates
EmailBackend mailer.Backend
Config config.Config
Files map[string][]byte
HTTP500Page []byte
DB *gorm.DB
Clock clock.Clock
EmailTemplates mailer.Templates
EmailBackend mailer.Backend
Files map[string][]byte
HTTP500Page []byte
AppEnv string
WebURL string
DisableRegistration bool
Port string
DBPath string
AssetBaseURL string
}
// Validate validates the app configuration
func (a *App) Validate() error {
if a.Config.WebURL == "" {
if a.WebURL == "" {
return ErrEmptyWebURL
}
if a.Clock == nil {

View file

@ -54,17 +54,17 @@ func TestCreateBook(t *testing.T) {
for idx, tc := range testCases {
func() {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData()
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
user := testutils.SetupUserData(db)
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData()
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData(db)
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
a := NewTest(&App{
Clock: clock.NewMock(),
})
a := NewTest()
a.DB = db
a.Clock = clock.NewMock()
book, err := a.CreateBook(user, tc.label)
if err != nil {
@ -75,13 +75,13 @@ func TestCreateBook(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
if err := testutils.DB.Model(&database.Book{}).Count(&bookCount).Error; err != nil {
if err := db.Model(&database.Book{}).Count(&bookCount).Error; err != nil {
t.Fatal(errors.Wrap(err, "counting books"))
}
if err := testutils.DB.First(&bookRecord).Error; err != nil {
if err := db.First(&bookRecord).Error; err != nil {
t.Fatal(errors.Wrap(err, "finding book"))
}
if err := testutils.DB.Where("id = ?", user.ID).First(&userRecord).Error; err != nil {
if err := db.Where("id = ?", user.ID).First(&userRecord).Error; err != nil {
t.Fatal(errors.Wrap(err, "finding user"))
}
@ -120,19 +120,20 @@ func TestDeleteBook(t *testing.T) {
for idx, tc := range testCases {
func() {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData()
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
user := testutils.SetupUserData(db)
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData()
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData(db)
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
book := database.Book{UserID: user.ID, Label: "js", Deleted: false}
testutils.MustExec(t, testutils.DB.Save(&book), fmt.Sprintf("preparing book for test case %d", idx))
testutils.MustExec(t, db.Save(&book), fmt.Sprintf("preparing book for test case %d", idx))
tx := testutils.DB.Begin()
a := NewTest(nil)
tx := db.Begin()
a := NewTest()
a.DB = db
ret, err := a.DeleteBook(tx, user, book)
if err != nil {
tx.Rollback()
@ -144,9 +145,9 @@ func TestDeleteBook(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
testutils.MustExec(t, testutils.DB.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
testutils.MustExec(t, db.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, bookCount, int64(1), "book count mismatch")
assert.Equal(t, bookRecord.UserID, user.ID, "book user_id mismatch")
@ -198,23 +199,23 @@ func TestUpdateBook(t *testing.T) {
for idx, tc := range testCases {
func() {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData()
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
user := testutils.SetupUserData(db)
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData()
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData(db)
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
b := database.Book{UserID: user.ID, Deleted: false, Label: tc.expectedLabel}
testutils.MustExec(t, testutils.DB.Save(&b), fmt.Sprintf("preparing book for test case %d", idx))
testutils.MustExec(t, db.Save(&b), fmt.Sprintf("preparing book for test case %d", idx))
c := clock.NewMock()
a := NewTest(&App{
Clock: c,
})
a := NewTest()
a.DB = db
a.Clock = c
tx := testutils.DB.Begin()
tx := db.Begin()
book, err := a.UpdateBook(tx, user, b, tc.payloadLabel)
if err != nil {
tx.Rollback()
@ -226,9 +227,9 @@ func TestUpdateBook(t *testing.T) {
var bookCount int64
var bookRecord database.Book
var userRecord database.User
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
testutils.MustExec(t, testutils.DB.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx))
testutils.MustExec(t, db.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx))
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, bookCount, int64(1), "book count mismatch")

View file

@ -23,7 +23,6 @@ import (
"net/url"
"strings"
"github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/mailer"
"github.com/pkg/errors"
)
@ -31,12 +30,8 @@ import (
var defaultSender = "admin@getdnote.com"
// GetSenderEmail returns the sender email
func GetSenderEmail(c config.Config, want string) (string, error) {
if !c.OnPremises {
return want, nil
}
addr, err := getNoreplySender(c)
func GetSenderEmail(webURL, want string) (string, error) {
addr, err := getNoreplySender(webURL)
if err != nil {
return "", errors.Wrap(err, "getting sender email address")
}
@ -60,8 +55,8 @@ func getDomainFromURL(rawURL string) (string, error) {
return domain, nil
}
func getNoreplySender(c config.Config) (string, error) {
domain, err := getDomainFromURL(c.WebURL)
func getNoreplySender(webURL string) (string, error) {
domain, err := getDomainFromURL(webURL)
if err != nil {
return "", errors.Wrap(err, "parsing web url")
}
@ -74,13 +69,13 @@ func getNoreplySender(c config.Config) (string, error) {
func (a *App) SendVerificationEmail(email, tokenValue string) error {
body, err := a.EmailTemplates.Execute(mailer.EmailTypeEmailVerification, mailer.EmailKindText, mailer.EmailVerificationTmplData{
Token: tokenValue,
WebURL: a.Config.WebURL,
WebURL: a.WebURL,
})
if err != nil {
return errors.Wrapf(err, "executing reset verification template for %s", email)
}
from, err := GetSenderEmail(a.Config, defaultSender)
from, err := GetSenderEmail(a.WebURL, defaultSender)
if err != nil {
return errors.Wrap(err, "getting the sender email")
}
@ -96,13 +91,13 @@ func (a *App) SendVerificationEmail(email, tokenValue string) error {
func (a *App) SendWelcomeEmail(email string) error {
body, err := a.EmailTemplates.Execute(mailer.EmailTypeWelcome, mailer.EmailKindText, mailer.WelcomeTmplData{
AccountEmail: email,
WebURL: a.Config.WebURL,
WebURL: a.WebURL,
})
if err != nil {
return errors.Wrapf(err, "executing reset verification template for %s", email)
}
from, err := GetSenderEmail(a.Config, defaultSender)
from, err := GetSenderEmail(a.WebURL, defaultSender)
if err != nil {
return errors.Wrap(err, "getting the sender email")
}
@ -123,13 +118,13 @@ func (a *App) SendPasswordResetEmail(email, tokenValue string) error {
body, err := a.EmailTemplates.Execute(mailer.EmailTypeResetPassword, mailer.EmailKindText, mailer.EmailResetPasswordTmplData{
AccountEmail: email,
Token: tokenValue,
WebURL: a.Config.WebURL,
WebURL: a.WebURL,
})
if err != nil {
return errors.Wrapf(err, "executing reset password template for %s", email)
}
from, err := GetSenderEmail(a.Config, defaultSender)
from, err := GetSenderEmail(a.WebURL, defaultSender)
if err != nil {
return errors.Wrap(err, "getting the sender email")
}
@ -149,13 +144,13 @@ func (a *App) SendPasswordResetEmail(email, tokenValue string) error {
func (a *App) SendPasswordResetAlertEmail(email string) error {
body, err := a.EmailTemplates.Execute(mailer.EmailTypeResetPasswordAlert, mailer.EmailKindText, mailer.EmailResetPasswordAlertTmplData{
AccountEmail: email,
WebURL: a.Config.WebURL,
WebURL: a.WebURL,
})
if err != nil {
return errors.Wrapf(err, "executing reset password alert template for %s", email)
}
from, err := GetSenderEmail(a.Config, defaultSender)
from, err := GetSenderEmail(a.WebURL, defaultSender)
if err != nil {
return errors.Wrap(err, "getting the sender email")
}

View file

@ -23,157 +23,74 @@ import (
"testing"
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/testutils"
)
func TestSendVerificationEmail(t *testing.T) {
testCases := []struct {
onPremise bool
expectedSender string
}{
{
onPremise: false,
expectedSender: "admin@getdnote.com",
},
{
onPremise: true,
expectedSender: "noreply@example.com",
},
emailBackend := testutils.MockEmailbackendImplementation{}
a := NewTest()
a.EmailBackend = &emailBackend
a.WebURL = "http://example.com"
if err := a.SendVerificationEmail("alice@example.com", "mockTokenValue"); err != nil {
t.Fatal(err, "failed to perform")
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("self hosted %t", tc.onPremise), func(t *testing.T) {
c := config.Load()
c.SetOnPremises(tc.onPremise)
c.WebURL = "http://example.com"
assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch")
assert.Equal(t, emailBackend.Emails[0].From, "noreply@example.com", "email sender mismatch")
assert.DeepEqual(t, emailBackend.Emails[0].To, []string{"alice@example.com"}, "email sender mismatch")
emailBackend := testutils.MockEmailbackendImplementation{}
a := NewTest(&App{
EmailBackend: &emailBackend,
Config: c,
})
if err := a.SendVerificationEmail("alice@example.com", "mockTokenValue"); err != nil {
t.Fatal(err, "failed to perform")
}
assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch")
assert.Equal(t, emailBackend.Emails[0].From, tc.expectedSender, "email sender mismatch")
assert.DeepEqual(t, emailBackend.Emails[0].To, []string{"alice@example.com"}, "email sender mismatch")
})
}
}
func TestSendWelcomeEmail(t *testing.T) {
testCases := []struct {
onPremise bool
expectedSender string
}{
{
onPremise: false,
expectedSender: "admin@getdnote.com",
},
{
onPremise: true,
expectedSender: "noreply@example.com",
},
emailBackend := testutils.MockEmailbackendImplementation{}
a := NewTest()
a.EmailBackend = &emailBackend
a.WebURL = "http://example.com"
if err := a.SendWelcomeEmail("alice@example.com"); err != nil {
t.Fatal(err, "failed to perform")
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("self hosted %t", tc.onPremise), func(t *testing.T) {
c := config.Load()
c.SetOnPremises(tc.onPremise)
c.WebURL = "http://example.com"
assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch")
assert.Equal(t, emailBackend.Emails[0].From, "noreply@example.com", "email sender mismatch")
assert.DeepEqual(t, emailBackend.Emails[0].To, []string{"alice@example.com"}, "email sender mismatch")
emailBackend := testutils.MockEmailbackendImplementation{}
a := NewTest(&App{
EmailBackend: &emailBackend,
Config: c,
})
if err := a.SendWelcomeEmail("alice@example.com"); err != nil {
t.Fatal(err, "failed to perform")
}
assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch")
assert.Equal(t, emailBackend.Emails[0].From, tc.expectedSender, "email sender mismatch")
assert.DeepEqual(t, emailBackend.Emails[0].To, []string{"alice@example.com"}, "email sender mismatch")
})
}
}
func TestSendPasswordResetEmail(t *testing.T) {
testCases := []struct {
onPremise bool
expectedSender string
}{
{
onPremise: false,
expectedSender: "admin@getdnote.com",
},
{
onPremise: true,
expectedSender: "noreply@example.com",
},
emailBackend := testutils.MockEmailbackendImplementation{}
a := NewTest()
a.EmailBackend = &emailBackend
a.WebURL = "http://example.com"
if err := a.SendPasswordResetEmail("alice@example.com", "mockTokenValue"); err != nil {
t.Fatal(err, "failed to perform")
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("self hosted %t", tc.onPremise), func(t *testing.T) {
c := config.Load()
c.SetOnPremises(tc.onPremise)
c.WebURL = "http://example.com"
assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch")
assert.Equal(t, emailBackend.Emails[0].From, "noreply@example.com", "email sender mismatch")
assert.DeepEqual(t, emailBackend.Emails[0].To, []string{"alice@example.com"}, "email sender mismatch")
emailBackend := testutils.MockEmailbackendImplementation{}
a := NewTest(&App{
EmailBackend: &emailBackend,
Config: c,
})
if err := a.SendPasswordResetEmail("alice@example.com", "mockTokenValue"); err != nil {
t.Fatal(err, "failed to perform")
}
assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch")
assert.Equal(t, emailBackend.Emails[0].From, tc.expectedSender, "email sender mismatch")
assert.DeepEqual(t, emailBackend.Emails[0].To, []string{"alice@example.com"}, "email sender mismatch")
})
}
}
func TestGetSenderEmail(t *testing.T) {
testCases := []struct {
onPremise bool
webURL string
candidate string
expectedSender string
}{
{
onPremise: true,
webURL: "https://www.example.com",
candidate: "alice@getdnote.com",
expectedSender: "noreply@example.com",
},
{
onPremise: false,
webURL: "https://www.getdnote.com",
candidate: "alice@getdnote.com",
expectedSender: "alice@getdnote.com",
webURL: "https://www.example2.com",
expectedSender: "alice@example2.com",
},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("on premise %t candidate %s", tc.onPremise, tc.candidate), func(t *testing.T) {
c := config.Load()
c.SetOnPremises(tc.onPremise)
c.WebURL = tc.webURL
got, err := GetSenderEmail(c, tc.candidate)
if err != nil {
t.Fatal(err, "failed to perform")
}
assert.Equal(t, got, tc.expectedSender, "result mismatch")
t.Run(fmt.Sprintf("web url %s", tc.webURL), func(t *testing.T) {
})
}
}

View file

@ -46,13 +46,13 @@ func TestIncremenetUserUSN(t *testing.T) {
// set up
for idx, tc := range testCases {
func() {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData()
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.maxUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
user := testutils.SetupUserData(db)
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.maxUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
// execute
tx := testutils.DB.Begin()
tx := db.Begin()
nextUSN, err := incrementUserUSN(tx, user.ID)
if err != nil {
t.Fatal(errors.Wrap(err, "incrementing the user usn"))
@ -61,7 +61,7 @@ func TestIncremenetUserUSN(t *testing.T) {
// test
var userRecord database.User
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, userRecord.MaxUSN, tc.expectedMaxUSN, fmt.Sprintf("user max_usn mismatch for case %d", idx))
assert.Equal(t, nextUSN, tc.expectedMaxUSN, fmt.Sprintf("next_usn mismatch for case %d", idx))

View file

@ -20,13 +20,12 @@ package app
import (
"errors"
"strings"
"time"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
"gorm.io/gorm"
pkgErrors "github.com/pkg/errors"
"gorm.io/gorm"
)
// CreateNote creates a note with the next usn and updates the user's max_usn.
@ -194,24 +193,16 @@ type ftsParams struct {
HighlightAll bool
}
func getHeadlineOptions(params *ftsParams) string {
headlineOptions := []string{
"StartSel=<dnotehl>",
"StopSel=</dnotehl>",
"ShortWord=0",
}
func getFTSBodyExpression(params *ftsParams) string {
if params != nil && params.HighlightAll {
headlineOptions = append(headlineOptions, "HighlightAll=true")
} else {
headlineOptions = append(headlineOptions, "MaxFragments=3, MaxWords=50, MinWords=10")
return "highlight(notes_fts, 0, '<dnotehl>', '</dnotehl>') AS body"
}
return strings.Join(headlineOptions, ",")
return "snippet(notes_fts, 0, '<dnotehl>', '</dnotehl>', '...', 50) AS body"
}
func selectFTSFields(conn *gorm.DB, search string, params *ftsParams) *gorm.DB {
headlineOpts := getHeadlineOptions(params)
func selectFTSFields(conn *gorm.DB, params *ftsParams) *gorm.DB {
bodyExpr := getFTSBodyExpression(params)
return conn.Select(`
notes.id,
@ -225,8 +216,7 @@ notes.edited_on,
notes.usn,
notes.deleted,
notes.encrypted,
ts_headline('english_nostop', notes.body, plainto_tsquery('english_nostop', ?), ?) AS body
`, search, headlineOpts)
` + bodyExpr)
}
func getNotesBaseQuery(db *gorm.DB, userID int, q GetNotesParams) *gorm.DB {
@ -236,8 +226,9 @@ func getNotesBaseQuery(db *gorm.DB, userID int, q GetNotesParams) *gorm.DB {
)
if q.Search != "" {
conn = selectFTSFields(conn, q.Search, nil)
conn = conn.Where("tsv @@ plainto_tsquery('english_nostop', ?)", q.Search)
conn = selectFTSFields(conn, nil)
conn = conn.Joins("INNER JOIN notes_fts ON notes_fts.rowid = notes.id")
conn = conn.Where("notes_fts MATCH ?", q.Search)
}
if len(q.Books) > 0 {

View file

@ -20,6 +20,7 @@ package app
import (
"fmt"
"strings"
"testing"
"time"
@ -74,36 +75,34 @@ func TestCreateNote(t *testing.T) {
for idx, tc := range testCases {
func() {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData()
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
user := testutils.SetupUserData(db)
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
fmt.Println(user)
anotherUser := testutils.SetupUserData()
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData(db)
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
b1 := database.Book{UserID: user.ID, Label: "js", Deleted: false}
testutils.MustExec(t, testutils.DB.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
testutils.MustExec(t, db.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
a := NewTest(&App{
Clock: mockClock,
})
a := NewTest()
a.DB = db
a.Clock = mockClock
tx := testutils.DB.Begin()
if _, err := a.CreateNote(user, b1.UUID, "note content", tc.addedOn, tc.editedOn, false, ""); err != nil {
tx.Rollback()
t.Fatal(errors.Wrap(err, "deleting note"))
t.Fatal(errors.Wrapf(err, "creating note for test case %d", idx))
}
tx.Commit()
var bookCount, noteCount int64
var noteRecord database.Note
var userRecord database.User
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting book for test case %d", idx))
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), fmt.Sprintf("counting notes for test case %d", idx))
testutils.MustExec(t, testutils.DB.First(&noteRecord), fmt.Sprintf("finding note for test case %d", idx))
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting book for test case %d", idx))
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), fmt.Sprintf("counting notes for test case %d", idx))
testutils.MustExec(t, db.First(&noteRecord), fmt.Sprintf("finding note for test case %d", idx))
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, bookCount, int64(1), "book count mismatch")
assert.Equal(t, noteCount, int64(1), "note count mismatch")
@ -116,10 +115,41 @@ func TestCreateNote(t *testing.T) {
assert.Equal(t, noteRecord.EditedOn, tc.expectedEditedOn, "note EditedOn mismatch")
assert.Equal(t, userRecord.MaxUSN, tc.expectedUSN, "user max_usn mismatch")
// Assert FTS table is updated
var ftsBody string
testutils.MustExec(t, db.Raw("SELECT body FROM notes_fts WHERE rowid = ?", noteRecord.ID).Scan(&ftsBody), fmt.Sprintf("querying notes_fts for test case %d", idx))
assert.Equal(t, ftsBody, "note content", "FTS body mismatch")
var searchCount int64
testutils.MustExec(t, db.Raw("SELECT COUNT(*) FROM notes_fts WHERE notes_fts MATCH ?", "content").Scan(&searchCount), "searching notes_fts")
assert.Equal(t, searchCount, int64(1), "Note should still be searchable")
}()
}
}
func TestCreateNote_EmptyBody(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
b1 := database.Book{UserID: user.ID, Label: "testBook"}
testutils.MustExec(t, db.Save(&b1), "preparing book")
a := NewTest()
a.DB = db
a.Clock = clock.NewMock()
// Create note with empty body
note, err := a.CreateNote(user, b1.UUID, "", nil, nil, false, "")
if err != nil {
t.Fatal(errors.Wrap(err, "creating note with empty body"))
}
// Assert FTS entry exists with empty body
var ftsBody string
testutils.MustExec(t, db.Raw("SELECT body FROM notes_fts WHERE rowid = ?", note.ID).Scan(&ftsBody), "querying notes_fts for empty body note")
assert.Equal(t, ftsBody, "", "FTS body should be empty for note created with empty body")
}
func TestUpdateNote(t *testing.T) {
testCases := []struct {
userUSN int
@ -137,35 +167,40 @@ func TestUpdateNote(t *testing.T) {
for idx, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData()
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), "preparing user max_usn for test case")
user := testutils.SetupUserData(db)
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), "preparing user max_usn for test case")
anotherUser := testutils.SetupUserData()
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), "preparing user max_usn for test case")
anotherUser := testutils.SetupUserData(db)
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), "preparing user max_usn for test case")
b1 := database.Book{UserID: user.ID, Label: "js", Deleted: false}
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1 for test case")
testutils.MustExec(t, db.Save(&b1), "preparing b1 for test case")
note := database.Note{UserID: user.ID, Deleted: false, Body: "test content", BookUUID: b1.UUID}
testutils.MustExec(t, testutils.DB.Save(&note), "preparing note for test case")
testutils.MustExec(t, db.Save(&note), "preparing note for test case")
// Assert FTS table has original content
var ftsBodyBefore string
testutils.MustExec(t, db.Raw("SELECT body FROM notes_fts WHERE rowid = ?", note.ID).Scan(&ftsBodyBefore), "querying notes_fts before update")
assert.Equal(t, ftsBodyBefore, "test content", "FTS body mismatch before update")
c := clock.NewMock()
content := "updated test content"
public := true
a := NewTest(&App{
Clock: c,
})
a := NewTest()
a.DB = db
a.Clock = c
tx := testutils.DB.Begin()
tx := db.Begin()
if _, err := a.UpdateNote(tx, user, note, &UpdateNoteParams{
Content: &content,
Public: &public,
}); err != nil {
tx.Rollback()
t.Fatal(errors.Wrap(err, "deleting note"))
t.Fatal(errors.Wrap(err, "updating note"))
}
tx.Commit()
@ -173,10 +208,10 @@ func TestUpdateNote(t *testing.T) {
var noteRecord database.Note
var userRecord database.User
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting book for test case")
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), "counting notes for test case")
testutils.MustExec(t, testutils.DB.First(&noteRecord), "finding note for test case")
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user for test case")
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting book for test case")
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), "counting notes for test case")
testutils.MustExec(t, db.First(&noteRecord), "finding note for test case")
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user for test case")
expectedUSN := tc.userUSN + 1
assert.Equal(t, bookCount, int64(1), "book count mismatch")
@ -187,10 +222,55 @@ func TestUpdateNote(t *testing.T) {
assert.Equal(t, noteRecord.Deleted, false, "note Deleted mismatch")
assert.Equal(t, noteRecord.USN, expectedUSN, "note USN mismatch")
assert.Equal(t, userRecord.MaxUSN, expectedUSN, "user MaxUSN mismatch")
// Assert FTS table is updated with new content
var ftsBodyAfter string
testutils.MustExec(t, db.Raw("SELECT body FROM notes_fts WHERE rowid = ?", noteRecord.ID).Scan(&ftsBodyAfter), "querying notes_fts after update")
assert.Equal(t, ftsBodyAfter, content, "FTS body mismatch after update")
var searchCount int64
testutils.MustExec(t, db.Raw("SELECT COUNT(*) FROM notes_fts WHERE notes_fts MATCH ?", "updated").Scan(&searchCount), "searching notes_fts")
assert.Equal(t, searchCount, int64(1), "Note should still be searchable")
})
}
}
func TestUpdateNote_SameContent(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
b1 := database.Book{UserID: user.ID, Label: "testBook"}
testutils.MustExec(t, db.Save(&b1), "preparing book")
note := database.Note{UserID: user.ID, Deleted: false, Body: "test content", BookUUID: b1.UUID}
testutils.MustExec(t, db.Save(&note), "preparing note")
a := NewTest()
a.DB = db
a.Clock = clock.NewMock()
// Update note with same content
sameContent := "test content"
tx := db.Begin()
_, err := a.UpdateNote(tx, user, note, &UpdateNoteParams{
Content: &sameContent,
})
if err != nil {
tx.Rollback()
t.Fatal(errors.Wrap(err, "updating note with same content"))
}
tx.Commit()
// Assert FTS still has the same content
var ftsBody string
testutils.MustExec(t, db.Raw("SELECT body FROM notes_fts WHERE rowid = ?", note.ID).Scan(&ftsBody), "querying notes_fts after update")
assert.Equal(t, ftsBody, "test content", "FTS body should still be 'test content'")
// Assert it's still searchable
var searchCount int64
testutils.MustExec(t, db.Raw("SELECT COUNT(*) FROM notes_fts WHERE notes_fts MATCH ?", "test").Scan(&searchCount), "searching notes_fts")
assert.Equal(t, searchCount, int64(1), "Note should still be searchable")
}
func TestDeleteNote(t *testing.T) {
testCases := []struct {
userUSN int
@ -212,23 +292,29 @@ func TestDeleteNote(t *testing.T) {
for idx, tc := range testCases {
func() {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData()
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
user := testutils.SetupUserData(db)
testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData()
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
anotherUser := testutils.SetupUserData(db)
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx))
b1 := database.Book{UserID: user.ID, Label: "testBook"}
testutils.MustExec(t, testutils.DB.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
testutils.MustExec(t, db.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx))
note := database.Note{UserID: user.ID, Deleted: false, Body: "test content", BookUUID: b1.UUID}
testutils.MustExec(t, testutils.DB.Save(&note), fmt.Sprintf("preparing note for test case %d", idx))
testutils.MustExec(t, db.Save(&note), fmt.Sprintf("preparing note for test case %d", idx))
a := NewTest(nil)
// Assert FTS table has content before delete
var ftsCountBefore int64
testutils.MustExec(t, db.Raw("SELECT COUNT(*) FROM notes_fts WHERE rowid = ?", note.ID).Scan(&ftsCountBefore), fmt.Sprintf("counting notes_fts before delete for test case %d", idx))
assert.Equal(t, ftsCountBefore, int64(1), "FTS should have entry before delete")
tx := testutils.DB.Begin()
a := NewTest()
a.DB = db
tx := db.Begin()
ret, err := a.DeleteNote(tx, user, note)
if err != nil {
tx.Rollback()
@ -240,9 +326,9 @@ func TestDeleteNote(t *testing.T) {
var noteRecord database.Note
var userRecord database.User
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), fmt.Sprintf("counting notes for test case %d", idx))
testutils.MustExec(t, testutils.DB.First(&noteRecord), fmt.Sprintf("finding note for test case %d", idx))
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), fmt.Sprintf("counting notes for test case %d", idx))
testutils.MustExec(t, db.First(&noteRecord), fmt.Sprintf("finding note for test case %d", idx))
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx))
assert.Equal(t, noteCount, int64(1), "note count mismatch")
@ -256,6 +342,184 @@ func TestDeleteNote(t *testing.T) {
assert.Equal(t, ret.Body, "", "note content mismatch")
assert.Equal(t, ret.Deleted, true, "note deleted flag mismatch")
assert.Equal(t, ret.USN, tc.expectedUSN, "note label mismatch")
// Assert FTS body is empty after delete (row still exists but content is cleared)
var ftsBody string
testutils.MustExec(t, db.Raw("SELECT body FROM notes_fts WHERE rowid = ?", noteRecord.ID).Scan(&ftsBody), fmt.Sprintf("querying notes_fts after delete for test case %d", idx))
assert.Equal(t, ftsBody, "", "FTS body should be empty after delete")
}()
}
}
func TestGetNotes_FTSSearch(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
b1 := database.Book{UserID: user.ID, Label: "testBook"}
testutils.MustExec(t, db.Save(&b1), "preparing book")
// Create notes with different content
note1 := database.Note{UserID: user.ID, Deleted: false, Body: "foo bar baz bar", BookUUID: b1.UUID}
testutils.MustExec(t, db.Save(&note1), "preparing note1")
note2 := database.Note{UserID: user.ID, Deleted: false, Body: "hello run foo", BookUUID: b1.UUID}
testutils.MustExec(t, db.Save(&note2), "preparing note2")
note3 := database.Note{UserID: user.ID, Deleted: false, Body: "running quz succeeded", BookUUID: b1.UUID}
testutils.MustExec(t, db.Save(&note3), "preparing note3")
a := NewTest()
a.DB = db
a.Clock = clock.NewMock()
// Search "baz"
result, err := a.GetNotes(user.ID, GetNotesParams{
Search: "baz",
Encrypted: false,
Page: 1,
PerPage: 30,
})
if err != nil {
t.Fatal(errors.Wrap(err, "getting notes with FTS search"))
}
assert.Equal(t, result.Total, int64(1), "Should find 1 note with 'baz'")
assert.Equal(t, len(result.Notes), 1, "Should return 1 note")
for i, note := range result.Notes {
assert.Equal(t, strings.Contains(note.Body, "<dnotehl>baz</dnotehl>"), true, fmt.Sprintf("Note %d should contain highlighted dnote", i))
}
// Search for "running" - should return 1 note
result, err = a.GetNotes(user.ID, GetNotesParams{
Search: "running",
Encrypted: false,
Page: 1,
PerPage: 30,
})
if err != nil {
t.Fatal(errors.Wrap(err, "getting notes with FTS search for review"))
}
assert.Equal(t, result.Total, int64(2), "Should find 2 note with 'running'")
assert.Equal(t, len(result.Notes), 2, "Should return 2 notes")
assert.Equal(t, result.Notes[0].Body, "<dnotehl>running</dnotehl> quz succeeded", "Should return the review note with highlighting")
assert.Equal(t, result.Notes[1].Body, "hello <dnotehl>run</dnotehl> foo", "Should return the review note with highlighting")
// Search for non-existent term - should return 0 notes
result, err = a.GetNotes(user.ID, GetNotesParams{
Search: "nonexistent",
Encrypted: false,
Page: 1,
PerPage: 30,
})
if err != nil {
t.Fatal(errors.Wrap(err, "getting notes with FTS search for nonexistent"))
}
assert.Equal(t, result.Total, int64(0), "Should find 0 notes with 'nonexistent'")
assert.Equal(t, len(result.Notes), 0, "Should return 0 notes")
}
func TestGetNotes_FTSSearch_Snippet(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
b1 := database.Book{UserID: user.ID, Label: "testBook"}
testutils.MustExec(t, db.Save(&b1), "preparing book")
// Create a long note to test snippet truncation with "..."
// The snippet limit is 50 tokens, so we generate enough words to exceed it
longBody := strings.Repeat("filler ", 100) + "the important keyword appears here"
longNote := database.Note{UserID: user.ID, Deleted: false, Body: longBody, BookUUID: b1.UUID}
testutils.MustExec(t, db.Save(&longNote), "preparing long note")
a := NewTest()
a.DB = db
a.Clock = clock.NewMock()
// Search for "keyword" in long note - should return snippet with "..."
result, err := a.GetNotes(user.ID, GetNotesParams{
Search: "keyword",
Encrypted: false,
Page: 1,
PerPage: 30,
})
if err != nil {
t.Fatal(errors.Wrap(err, "getting notes with FTS search for keyword"))
}
assert.Equal(t, result.Total, int64(1), "Should find 1 note with 'keyword'")
assert.Equal(t, len(result.Notes), 1, "Should return 1 note")
// The snippet should contain "..." to indicate truncation and the highlighted keyword
assert.Equal(t, strings.Contains(result.Notes[0].Body, "..."), true, "Snippet should contain '...' for truncation")
assert.Equal(t, strings.Contains(result.Notes[0].Body, "<dnotehl>keyword</dnotehl>"), true, "Snippet should contain highlighted keyword")
}
func TestGetNotes_FTSSearch_ShortWord(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
b1 := database.Book{UserID: user.ID, Label: "testBook"}
testutils.MustExec(t, db.Save(&b1), "preparing book")
// Create notes with short words
note1 := database.Note{UserID: user.ID, Deleted: false, Body: "a b c", BookUUID: b1.UUID}
testutils.MustExec(t, db.Save(&note1), "preparing note1")
note2 := database.Note{UserID: user.ID, Deleted: false, Body: "d", BookUUID: b1.UUID}
testutils.MustExec(t, db.Save(&note2), "preparing note2")
a := NewTest()
a.DB = db
a.Clock = clock.NewMock()
result, err := a.GetNotes(user.ID, GetNotesParams{
Search: "a",
Encrypted: false,
Page: 1,
PerPage: 30,
})
if err != nil {
t.Fatal(errors.Wrap(err, "getting notes with FTS search for 'a'"))
}
assert.Equal(t, result.Total, int64(1), "Should find 1 note")
assert.Equal(t, len(result.Notes), 1, "Should return 1 note")
assert.Equal(t, strings.Contains(result.Notes[0].Body, "<dnotehl>a</dnotehl>"), true, "Should contain highlighted 'a'")
}
func TestGetNotes_All(t *testing.T) {
db := testutils.InitMemoryDB(t)
user := testutils.SetupUserData(db)
b1 := database.Book{UserID: user.ID, Label: "testBook"}
testutils.MustExec(t, db.Save(&b1), "preparing book")
note1 := database.Note{UserID: user.ID, Deleted: false, Body: "a b c", BookUUID: b1.UUID}
testutils.MustExec(t, db.Save(&note1), "preparing note1")
note2 := database.Note{UserID: user.ID, Deleted: false, Body: "d", BookUUID: b1.UUID}
testutils.MustExec(t, db.Save(&note2), "preparing note2")
a := NewTest()
a.DB = db
a.Clock = clock.NewMock()
result, err := a.GetNotes(user.ID, GetNotesParams{
Search: "",
Encrypted: false,
Page: 1,
PerPage: 30,
})
if err != nil {
t.Fatal(errors.Wrap(err, "getting notes with FTS search for 'a'"))
}
assert.Equal(t, result.Total, int64(2), "Should not find all notes")
assert.Equal(t, len(result.Notes), 2, "Should not find all notes")
for _, note := range result.Notes {
assert.Equal(t, strings.Contains(note.Body, "<dnotehl>"), false, "There should be no keywords")
assert.Equal(t, strings.Contains(note.Body, "</dnotehl>"), false, "There should be no keywords")
}
assert.Equal(t, result.Notes[0].Body, "d", "Full content should be returned")
assert.Equal(t, result.Notes[1].Body, "a b c", "Full content should be returned")
}

View file

@ -19,50 +19,24 @@
package app
import (
"fmt"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/assets"
"github.com/dnote/dnote/pkg/server/mailer"
"github.com/dnote/dnote/pkg/server/testutils"
)
// NewTest returns an app for a testing environment
func NewTest(appParams *App) App {
c := config.Load()
c.SetOnPremises(false)
a := App{
DB: testutils.DB,
Clock: clock.NewMock(),
EmailTemplates: mailer.NewTemplates(),
EmailBackend: &testutils.MockEmailbackendImplementation{},
Config: c,
HTTP500Page: []byte("<html></html>"),
func NewTest() App {
return App{
Clock: clock.NewMock(),
EmailTemplates: mailer.NewTemplates(),
EmailBackend: &testutils.MockEmailbackendImplementation{},
HTTP500Page: assets.MustGetHTTP500ErrorPage(),
AppEnv: "TEST",
WebURL: "http://127.0.0.0.1",
Port: "3000",
DisableRegistration: false,
DBPath: ":memory:",
AssetBaseURL: "",
}
// Allow to override with appParams
if appParams != nil && appParams.EmailBackend != nil {
a.EmailBackend = appParams.EmailBackend
}
if appParams != nil && appParams.Clock != nil {
a.Clock = appParams.Clock
}
if appParams != nil && appParams.EmailTemplates != nil {
a.EmailTemplates = appParams.EmailTemplates
}
if appParams != nil && appParams.Config.OnPremises {
a.Config.OnPremises = appParams.Config.OnPremises
}
if appParams != nil && appParams.Config.WebURL != "" {
a.Config.WebURL = appParams.Config.WebURL
}
if appParams != nil && appParams.Config.DisableRegistration {
a.Config.DisableRegistration = appParams.Config.DisableRegistration
}
fmt.Printf("%+v\n", appParams)
fmt.Printf("%+v\n", a)
return a
}

View file

@ -22,11 +22,11 @@ import (
"errors"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/helpers"
"github.com/dnote/dnote/pkg/server/log"
"github.com/dnote/dnote/pkg/server/token"
"gorm.io/gorm"
pkgErrors "github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// TouchLastLoginAt updates the last login timestamp
@ -39,17 +39,6 @@ func (a *App) TouchLastLoginAt(user database.User, tx *gorm.DB) error {
return nil
}
func createEmailPreference(user database.User, tx *gorm.DB) error {
p := database.EmailPreference{
UserID: user.ID,
}
if err := tx.Save(&p).Error; err != nil {
return pkgErrors.Wrap(err, "inserting email preference")
}
return nil
}
// CreateUser creates a user
func (a *App) CreateUser(email, password string, passwordConfirmation string) (database.User, error) {
if email == "" {
@ -80,16 +69,14 @@ func (a *App) CreateUser(email, password string, passwordConfirmation string) (d
return database.User{}, pkgErrors.Wrap(err, "hashing password")
}
// Grant all privileges if self-hosting
var pro bool
if a.Config.OnPremises {
pro = true
} else {
pro = false
uuid, err := helpers.GenUUID()
if err != nil {
tx.Rollback()
return database.User{}, pkgErrors.Wrap(err, "generating UUID")
}
user := database.User{
Cloud: pro,
UUID: uuid,
}
if err = tx.Save(&user).Error; err != nil {
tx.Rollback()
@ -105,14 +92,6 @@ func (a *App) CreateUser(email, password string, passwordConfirmation string) (d
return database.User{}, pkgErrors.Wrap(err, "saving account")
}
if _, err := token.Create(tx, user.ID, database.TokenTypeEmailPreference); err != nil {
tx.Rollback()
return database.User{}, pkgErrors.Wrap(err, "creating email verificaiton token")
}
if err := createEmailPreference(user, tx); err != nil {
tx.Rollback()
return database.User{}, pkgErrors.Wrap(err, "creating email preference")
}
if err := a.TouchLastLoginAt(user, tx); err != nil {
tx.Rollback()
return database.User{}, pkgErrors.Wrap(err, "updating last login")

View file

@ -19,11 +19,9 @@
package app
import (
"fmt"
"testing"
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
@ -31,65 +29,41 @@ import (
)
func TestCreateUser_ProValue(t *testing.T) {
testCases := []struct {
onPremises bool
expectedPro bool
}{
{
onPremises: true,
expectedPro: true,
},
{
onPremises: false,
expectedPro: false,
},
db := testutils.InitMemoryDB(t)
a := NewTest()
a.DB = db
if _, err := a.CreateUser("alice@example.com", "pass1234", "pass1234"); err != nil {
t.Fatal(errors.Wrap(err, "executing"))
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("self hosting %t", tc.onPremises), func(t *testing.T) {
c := config.Load()
c.SetOnPremises(tc.onPremises)
var userCount int64
var userRecord database.User
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
testutils.MustExec(t, db.First(&userRecord), "finding user")
defer testutils.ClearData(testutils.DB)
assert.Equal(t, userCount, int64(1), "book count mismatch")
a := NewTest(&App{
Config: c,
})
if _, err := a.CreateUser("alice@example.com", "pass1234", "pass1234"); err != nil {
t.Fatal(errors.Wrap(err, "executing"))
}
var userCount int64
var userRecord database.User
testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
testutils.MustExec(t, testutils.DB.First(&userRecord), "finding user")
assert.Equal(t, userCount, int64(1), "book count mismatch")
assert.Equal(t, userRecord.Cloud, tc.expectedPro, "user pro mismatch")
})
}
}
func TestCreateUser(t *testing.T) {
t.Run("success", func(t *testing.T) {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
c := config.Load()
a := NewTest(&App{
Config: c,
})
a := NewTest()
a.DB = db
if _, err := a.CreateUser("alice@example.com", "pass1234", "pass1234"); err != nil {
t.Fatal(errors.Wrap(err, "executing"))
}
var userCount int64
testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
assert.Equal(t, userCount, int64(1), "book count mismatch")
var accountCount int64
var accountRecord database.Account
testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account")
testutils.MustExec(t, testutils.DB.First(&accountRecord), "finding account")
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
testutils.MustExec(t, db.First(&accountRecord), "finding account")
assert.Equal(t, accountCount, int64(1), "account count mismatch")
assert.Equal(t, accountRecord.Email.String, "alice@example.com", "account email mismatch")
@ -99,19 +73,20 @@ func TestCreateUser(t *testing.T) {
})
t.Run("duplicate email", func(t *testing.T) {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
aliceUser := testutils.SetupUserData()
testutils.SetupAccountData(aliceUser, "alice@example.com", "somepassword")
aliceUser := testutils.SetupUserData(db)
testutils.SetupAccountData(db, aliceUser, "alice@example.com", "somepassword")
a := NewTest(nil)
a := NewTest()
a.DB = db
_, err := a.CreateUser("alice@example.com", "newpassword", "newpassword")
assert.Equal(t, err, ErrDuplicateEmail, "error mismatch")
var userCount, accountCount int64
testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user")
testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account")
testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user")
testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account")
assert.Equal(t, userCount, int64(1), "user count mismatch")
assert.Equal(t, accountCount, int64(1), "account count mismatch")

View file

@ -19,149 +19,94 @@
package config
import (
"fmt"
"net/url"
"os"
"path/filepath"
"github.com/dnote/dnote/pkg/dirs"
"github.com/dnote/dnote/pkg/server/assets"
"github.com/dnote/dnote/pkg/server/log"
"github.com/pkg/errors"
)
const (
// AppEnvProduction represents an app environment for production.
AppEnvProduction string = "PRODUCTION"
// DefaultDBDir is the default directory name for Dnote data
DefaultDBDir = "dnote"
// DefaultDBFilename is the default database filename
DefaultDBFilename = "server.db"
)
var (
// ErrDBMissingHost is an error for an incomplete configuration missing the host
ErrDBMissingHost = errors.New("DB Host is empty")
// ErrDBMissingPort is an error for an incomplete configuration missing the port
ErrDBMissingPort = errors.New("DB Port is empty")
// ErrDBMissingName is an error for an incomplete configuration missing the name
ErrDBMissingName = errors.New("DB Name is empty")
// ErrDBMissingUser is an error for an incomplete configuration missing the user
ErrDBMissingUser = errors.New("DB User is empty")
// DefaultDBPath is the default path to the database file
DefaultDBPath = filepath.Join(dirs.DataHome, DefaultDBDir, DefaultDBFilename)
)
var (
// ErrDBMissingPath is an error for an incomplete configuration missing the database path
ErrDBMissingPath = errors.New("DB Path is empty")
// ErrWebURLInvalid is an error for an incomplete configuration with invalid web url
ErrWebURLInvalid = errors.New("Invalid WebURL")
// ErrPortInvalid is an error for an incomplete configuration with invalid port
ErrPortInvalid = errors.New("Invalid Port")
)
// PostgresConfig holds the postgres connection configuration.
type PostgresConfig struct {
SSLMode string
Host string
Port string
Name string
User string
Password string
}
func readBoolEnv(name string) bool {
if os.Getenv(name) == "true" {
return true
}
return false
return os.Getenv(name) == "true"
}
// checkSSLMode checks if SSL is required for the database connection
func checkSSLMode() bool {
// TODO: deprecate DB_NOSSL in favor of DBSkipSSL
if os.Getenv("DB_NOSSL") != "" {
return true
// getOrEnv returns value if non-empty, otherwise env var, otherwise default
func getOrEnv(value, envKey, defaultVal string) string {
if value != "" {
return value
}
if os.Getenv("DBSkipSSL") == "true" {
return true
}
return os.Getenv("GO_ENV") != "PRODUCTION"
}
func loadDBConfig() PostgresConfig {
var sslmode string
if checkSSLMode() {
sslmode = "disable"
} else {
sslmode = "require"
}
return PostgresConfig{
SSLMode: sslmode,
Host: os.Getenv("DBHost"),
Port: os.Getenv("DBPort"),
Name: os.Getenv("DBName"),
User: os.Getenv("DBUser"),
Password: os.Getenv("DBPassword"),
if env := os.Getenv(envKey); env != "" {
return env
}
return defaultVal
}
// Config is an application configuration
type Config struct {
AppEnv string
WebURL string
OnPremises bool
DisableRegistration bool
Port string
DB PostgresConfig
DBPath string
AssetBaseURL string
HTTP500Page []byte
LogLevel string
}
func getAppEnv() string {
// DEPRECATED
goEnv := os.Getenv("GO_ENV")
if goEnv != "" {
return goEnv
}
return os.Getenv("APP_ENV")
// Params are the configuration parameters for creating a new Config
type Params struct {
AppEnv string
Port string
WebURL string
DBPath string
DisableRegistration bool
LogLevel string
}
func checkDeprecatedEnvVars() {
if os.Getenv("OnPremise") != "" {
log.WithFields(log.Fields{}).Warn("Environment variable 'OnPremise' is deprecated. Please use OnPremises.")
}
}
// Load constructs and returns a new config based on the environment variables.
func Load() Config {
port := os.Getenv("PORT")
if port == "" {
port = "3000"
}
checkDeprecatedEnvVars()
// New constructs and returns a new validated config.
// Empty string params will fall back to environment variables and defaults.
func New(p Params) (Config, error) {
c := Config{
AppEnv: getAppEnv(),
WebURL: os.Getenv("WebURL"),
Port: port,
OnPremises: readBoolEnv("OnPremise") || readBoolEnv("OnPremises"),
DisableRegistration: readBoolEnv("DisableRegistration"),
DB: loadDBConfig(),
AssetBaseURL: "",
AppEnv: getOrEnv(p.AppEnv, "APP_ENV", AppEnvProduction),
Port: getOrEnv(p.Port, "PORT", "3000"),
WebURL: getOrEnv(p.WebURL, "WebURL", ""),
DBPath: getOrEnv(p.DBPath, "DBPath", DefaultDBPath),
DisableRegistration: p.DisableRegistration || readBoolEnv("DisableRegistration"),
LogLevel: getOrEnv(p.LogLevel, "LOG_LEVEL", "info"),
AssetBaseURL: "/static",
HTTP500Page: assets.MustGetHTTP500ErrorPage(),
}
if err := validate(c); err != nil {
panic(err)
return Config{}, err
}
return c
}
// SetOnPremises sets the OnPremise value
func (c *Config) SetOnPremises(val bool) {
c.OnPremises = val
}
// SetAssetBaseURL sets static dir for the confi
func (c *Config) SetAssetBaseURL(d string) {
c.AssetBaseURL = d
return c, nil
}
// IsProd checks if the app environment is configured to be production.
@ -171,31 +116,15 @@ func (c Config) IsProd() bool {
func validate(c Config) error {
if _, err := url.ParseRequestURI(c.WebURL); err != nil {
return errors.Wrapf(ErrWebURLInvalid, "provided: '%s'", c.WebURL)
return errors.Wrapf(ErrWebURLInvalid, "'%s'", c.WebURL)
}
if c.Port == "" {
return ErrPortInvalid
}
if c.DB.Host == "" {
return ErrDBMissingHost
}
if c.DB.Port == "" {
return ErrDBMissingPort
}
if c.DB.Name == "" {
return ErrDBMissingName
}
if c.DB.User == "" {
return ErrDBMissingUser
if c.DBPath == "" {
return ErrDBMissingPath
}
return nil
}
// GetConnectionStr returns a postgres connection string.
func (c PostgresConfig) GetConnectionStr() string {
return fmt.Sprintf(
"sslmode=%s host=%s port=%s dbname=%s user=%s password=%s",
c.SSLMode, c.Host, c.Port, c.Name, c.User, c.Password)
}

View file

@ -33,12 +33,7 @@ func TestValidate(t *testing.T) {
}{
{
config: Config{
DB: PostgresConfig{
Host: "mockHost",
Port: "5432",
Name: "mockDB",
User: "mockUser",
},
DBPath: "test.db",
WebURL: "http://mock.url",
Port: "3000",
},
@ -46,71 +41,21 @@ func TestValidate(t *testing.T) {
},
{
config: Config{
DB: PostgresConfig{
Port: "5432",
Name: "mockDB",
User: "mockUser",
},
DBPath: "",
WebURL: "http://mock.url",
Port: "3000",
},
expectedErr: ErrDBMissingHost,
expectedErr: ErrDBMissingPath,
},
{
config: Config{
DB: PostgresConfig{
Host: "mockHost",
Name: "mockDB",
User: "mockUser",
},
WebURL: "http://mock.url",
Port: "3000",
},
expectedErr: ErrDBMissingPort,
},
{
config: Config{
DB: PostgresConfig{
Host: "mockHost",
Port: "5432",
User: "mockUser",
},
WebURL: "http://mock.url",
Port: "3000",
},
expectedErr: ErrDBMissingName,
},
{
config: Config{
DB: PostgresConfig{
Host: "mockHost",
Port: "5432",
Name: "mockDB",
},
WebURL: "http://mock.url",
Port: "3000",
},
expectedErr: ErrDBMissingUser,
},
{
config: Config{
DB: PostgresConfig{
Host: "mockHost",
Port: "5432",
Name: "mockDB",
User: "mockUser",
},
DBPath: "test.db",
},
expectedErr: ErrWebURLInvalid,
},
{
config: Config{
DB: PostgresConfig{
Host: "mockHost",
Port: "5432",
Name: "mockDB",
User: "mockUser",
},
DBPath: "test.db",
WebURL: "http://mock.url",
},
expectedErr: ErrPortInvalid,

View file

@ -21,69 +21,78 @@ package controllers
import (
"encoding/json"
"fmt"
"io/ioutil"
"io"
"net/http"
"testing"
"time"
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/app"
"github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/testutils"
"github.com/pkg/errors"
)
// truncateMicro rounds time to microsecond precision to match SQLite storage
func truncateMicro(t time.Time) time.Time {
return t.Round(time.Microsecond)
}
func TestGetBooks(t *testing.T) {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
// Setup
server := MustNewServer(t, &app.App{
Clock: clock.NewMock(),
Config: config.Config{},
})
a := app.NewTest()
a.DB = db
a.Clock = clock.NewMock()
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData()
testutils.SetupAccountData(user, "alice@test.com", "pass1234")
anotherUser := testutils.SetupUserData()
testutils.SetupAccountData(anotherUser, "bob@test.com", "pass1234")
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
anotherUser := testutils.SetupUserData(db)
testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
b1 := database.Book{
UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
USN: 1123,
Deleted: false,
}
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
testutils.MustExec(t, db.Save(&b1), "preparing b1")
b2 := database.Book{
UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "css",
USN: 1125,
Deleted: false,
}
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
testutils.MustExec(t, db.Save(&b2), "preparing b2")
b3 := database.Book{
UUID: testutils.MustUUID(t),
UserID: anotherUser.ID,
Label: "css",
USN: 1128,
Deleted: false,
}
testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
testutils.MustExec(t, db.Save(&b3), "preparing b3")
b4 := database.Book{
UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "",
USN: 1129,
Deleted: true,
}
testutils.MustExec(t, testutils.DB.Save(&b4), "preparing b4")
testutils.MustExec(t, db.Save(&b4), "preparing b4")
// Execute
endpoint := "/api/v3/books"
req := testutils.MakeReq(server.URL, "GET", endpoint, "")
res := testutils.HTTPAuthDo(t, req, user)
res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "")
@ -94,66 +103,75 @@ func TestGetBooks(t *testing.T) {
}
var b1Record, b2Record database.Book
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
testutils.MustExec(t, testutils.DB.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
testutils.MustExec(t, testutils.DB.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
testutils.MustExec(t, db.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
testutils.MustExec(t, db.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
expected := []presenters.Book{
{
UUID: b2Record.UUID,
CreatedAt: b2Record.CreatedAt,
UpdatedAt: b2Record.UpdatedAt,
CreatedAt: truncateMicro(b2Record.CreatedAt),
UpdatedAt: truncateMicro(b2Record.UpdatedAt),
Label: b2Record.Label,
USN: b2Record.USN,
},
{
UUID: b1Record.UUID,
CreatedAt: b1Record.CreatedAt,
UpdatedAt: b1Record.UpdatedAt,
CreatedAt: truncateMicro(b1Record.CreatedAt),
UpdatedAt: truncateMicro(b1Record.UpdatedAt),
Label: b1Record.Label,
USN: b1Record.USN,
},
}
// Truncate payload timestamps to match SQLite precision
for i := range payload {
payload[i].CreatedAt = truncateMicro(payload[i].CreatedAt)
payload[i].UpdatedAt = truncateMicro(payload[i].UpdatedAt)
}
assert.DeepEqual(t, payload, expected, "payload mismatch")
}
func TestGetBooksByName(t *testing.T) {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
// Setup
server := MustNewServer(t, &app.App{
Clock: clock.NewMock(),
Config: config.Config{},
})
a := app.NewTest()
a.DB = db
a.Clock = clock.NewMock()
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData()
testutils.SetupAccountData(user, "alice@test.com", "pass1234")
anotherUser := testutils.SetupUserData()
testutils.SetupAccountData(anotherUser, "bob@test.com", "pass1234")
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
anotherUser := testutils.SetupUserData(db)
testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
b1 := database.Book{
UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
testutils.MustExec(t, db.Save(&b1), "preparing b1")
b2 := database.Book{
UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "css",
}
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
testutils.MustExec(t, db.Save(&b2), "preparing b2")
b3 := database.Book{
UUID: testutils.MustUUID(t),
UserID: anotherUser.ID,
Label: "js",
}
testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
testutils.MustExec(t, db.Save(&b3), "preparing b3")
// Execute
endpoint := "/api/v3/books?name=js"
req := testutils.MakeReq(server.URL, "GET", endpoint, "")
res := testutils.HTTPAuthDo(t, req, user)
res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "")
@ -164,56 +182,64 @@ func TestGetBooksByName(t *testing.T) {
}
var b1Record database.Book
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
expected := []presenters.Book{
{
UUID: b1Record.UUID,
CreatedAt: b1Record.CreatedAt,
UpdatedAt: b1Record.UpdatedAt,
CreatedAt: truncateMicro(b1Record.CreatedAt),
UpdatedAt: truncateMicro(b1Record.UpdatedAt),
Label: b1Record.Label,
USN: b1Record.USN,
},
}
for i := range payload {
payload[i].CreatedAt = truncateMicro(payload[i].CreatedAt)
payload[i].UpdatedAt = truncateMicro(payload[i].UpdatedAt)
}
assert.DeepEqual(t, payload, expected, "payload mismatch")
}
func TestGetBook(t *testing.T) {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
// Setup
server := MustNewServer(t, &app.App{
Clock: clock.NewMock(),
Config: config.Config{},
})
a := app.NewTest()
a.DB = db
a.Clock = clock.NewMock()
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData()
testutils.SetupAccountData(user, "alice@test.com", "pass1234")
anotherUser := testutils.SetupUserData()
testutils.SetupAccountData(anotherUser, "bob@test.com", "pass1234")
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
anotherUser := testutils.SetupUserData(db)
testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
b1 := database.Book{
UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
testutils.MustExec(t, db.Save(&b1), "preparing b1")
b2 := database.Book{
UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "css",
}
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
testutils.MustExec(t, db.Save(&b2), "preparing b2")
b3 := database.Book{
UUID: testutils.MustUUID(t),
UserID: anotherUser.ID,
Label: "js",
}
testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
testutils.MustExec(t, db.Save(&b3), "preparing b3")
// Execute
endpoint := fmt.Sprintf("/api/v3/books/%s", b1.UUID)
req := testutils.MakeReq(server.URL, "GET", endpoint, "")
res := testutils.HTTPAuthDo(t, req, user)
res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "")
@ -224,49 +250,53 @@ func TestGetBook(t *testing.T) {
}
var b1Record database.Book
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
expected := presenters.Book{
UUID: b1Record.UUID,
CreatedAt: b1Record.CreatedAt,
UpdatedAt: b1Record.UpdatedAt,
CreatedAt: truncateMicro(b1Record.CreatedAt),
UpdatedAt: truncateMicro(b1Record.UpdatedAt),
Label: b1Record.Label,
USN: b1Record.USN,
}
payload.CreatedAt = truncateMicro(payload.CreatedAt)
payload.UpdatedAt = truncateMicro(payload.UpdatedAt)
assert.DeepEqual(t, payload, expected, "payload mismatch")
}
func TestGetBookNonOwner(t *testing.T) {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
// Setup
server := MustNewServer(t, &app.App{
Clock: clock.NewMock(),
Config: config.Config{},
})
a := app.NewTest()
a.DB = db
a.Clock = clock.NewMock()
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData()
testutils.SetupAccountData(user, "alice@test.com", "pass1234")
nonOwner := testutils.SetupUserData()
testutils.SetupAccountData(nonOwner, "bob@test.com", "pass1234")
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
nonOwner := testutils.SetupUserData(db)
testutils.SetupAccountData(db, nonOwner, "bob@test.com", "pass1234")
b1 := database.Book{
UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
testutils.MustExec(t, db.Save(&b1), "preparing b1")
// Execute
endpoint := fmt.Sprintf("/api/v3/books/%s", b1.UUID)
req := testutils.MakeReq(server.URL, "GET", endpoint, "")
res := testutils.HTTPAuthDo(t, req, nonOwner)
res := testutils.HTTPAuthDo(t, db, req, nonOwner)
// Test
assert.StatusCodeEquals(t, res, http.StatusNotFound, "")
body, err := ioutil.ReadAll(res.Body)
body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(errors.Wrap(err, "reading body"))
}
@ -275,23 +305,23 @@ func TestGetBookNonOwner(t *testing.T) {
func TestCreateBook(t *testing.T) {
t.Run("success", func(t *testing.T) {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
// Setup
server := MustNewServer(t, &app.App{
Clock: clock.NewMock(),
Config: config.Config{},
})
a := app.NewTest()
a.DB = db
a.Clock = clock.NewMock()
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData()
testutils.SetupAccountData(user, "alice@test.com", "pass1234")
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
req := testutils.MakeReq(server.URL, "POST", "/api/v3/books", `{"name": "js"}`)
// Execute
res := testutils.HTTPAuthDo(t, req, user)
res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusCreated, "")
@ -299,10 +329,10 @@ func TestCreateBook(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
var bookCount, noteCount int64
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, testutils.DB.First(&bookRecord), "finding book")
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, db.First(&bookRecord), "finding book")
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
maxUSN := 102
@ -323,39 +353,43 @@ func TestCreateBook(t *testing.T) {
Book: presenters.Book{
UUID: bookRecord.UUID,
USN: bookRecord.USN,
CreatedAt: bookRecord.CreatedAt,
UpdatedAt: bookRecord.UpdatedAt,
CreatedAt: truncateMicro(bookRecord.CreatedAt),
UpdatedAt: truncateMicro(bookRecord.UpdatedAt),
Label: "js",
},
}
got.Book.CreatedAt = truncateMicro(got.Book.CreatedAt)
got.Book.UpdatedAt = truncateMicro(got.Book.UpdatedAt)
assert.DeepEqual(t, got, expected, "payload mismatch")
})
t.Run("duplicate", func(t *testing.T) {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
// Setup
server := MustNewServer(t, &app.App{
Clock: clock.NewMock(),
Config: config.Config{},
})
a := app.NewTest()
a.DB = db
a.Clock = clock.NewMock()
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData()
testutils.SetupAccountData(user, "alice@test.com", "pass1234")
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
USN: 58,
}
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing book data")
testutils.MustExec(t, db.Save(&b1), "preparing book data")
// Execute
req := testutils.MakeReq(server.URL, "POST", "/api/v3/books", `{"name": "js"}`)
res := testutils.HTTPAuthDo(t, req, user)
res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusConflict, "")
@ -363,10 +397,10 @@ func TestCreateBook(t *testing.T) {
var bookRecord database.Book
var bookCount, noteCount int64
var userRecord database.User
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, testutils.DB.First(&bookRecord), "finding book")
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, db.First(&bookRecord), "finding book")
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, int64(1), "book count mismatch")
assert.Equalf(t, noteCount, int64(0), "note count mismatch")
@ -422,18 +456,18 @@ func TestUpdateBook(t *testing.T) {
for idx, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
// Setup
server := MustNewServer(t, &app.App{
Clock: clock.NewMock(),
Config: config.Config{},
})
a := app.NewTest()
a.DB = db
a.Clock = clock.NewMock()
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData()
testutils.SetupAccountData(user, "alice@test.com", "pass1234")
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
UUID: tc.bookUUID,
@ -441,18 +475,18 @@ func TestUpdateBook(t *testing.T) {
Label: tc.bookLabel,
Deleted: tc.bookDeleted,
}
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
testutils.MustExec(t, db.Save(&b1), "preparing b1")
b2 := database.Book{
UUID: b2UUID,
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
testutils.MustExec(t, db.Save(&b2), "preparing b2")
// Execute
endpoint := fmt.Sprintf("/api/v3/books/%s", tc.bookUUID)
req := testutils.MakeReq(server.URL, "PATCH", endpoint, tc.payload.ToJSON(t))
res := testutils.HTTPAuthDo(t, req, user)
res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, fmt.Sprintf("status code mismatch for test case %d", idx))
@ -460,10 +494,10 @@ func TestUpdateBook(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
var noteCount, bookCount int64
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, int64(2), "book count mismatch")
assert.Equalf(t, noteCount, int64(0), "note count mismatch")
@ -507,41 +541,44 @@ func TestDeleteBook(t *testing.T) {
for _, tc := range testCases {
t.Run(fmt.Sprintf("originally deleted %t", tc.deleted), func(t *testing.T) {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
// Setup
server := MustNewServer(t, &app.App{
Clock: clock.NewMock(),
Config: config.Config{},
})
a := app.NewTest()
a.DB = db
a.Clock = clock.NewMock()
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData()
testutils.SetupAccountData(user, "alice@test.com", "pass1234")
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 58), "preparing user max_usn")
anotherUser := testutils.SetupUserData()
testutils.SetupAccountData(anotherUser, "bob@test.com", "pass1234")
testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 109), "preparing another user max_usn")
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
testutils.MustExec(t, db.Model(&user).Update("max_usn", 58), "preparing user max_usn")
anotherUser := testutils.SetupUserData(db)
testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 109), "preparing another user max_usn")
b1 := database.Book{
UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
USN: 1,
}
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing a book data")
testutils.MustExec(t, db.Save(&b1), "preparing a book data")
b2 := database.Book{
UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: tc.label,
USN: 2,
Deleted: tc.deleted,
}
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing a book data")
testutils.MustExec(t, db.Save(&b2), "preparing a book data")
b3 := database.Book{
UUID: testutils.MustUUID(t),
UserID: anotherUser.ID,
Label: "linux",
USN: 3,
}
testutils.MustExec(t, testutils.DB.Save(&b3), "preparing a book data")
testutils.MustExec(t, db.Save(&b3), "preparing a book data")
var n2Body string
if !tc.deleted {
@ -553,49 +590,54 @@ func TestDeleteBook(t *testing.T) {
}
n1 := database.Note{
UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: "n1 content",
USN: 4,
}
testutils.MustExec(t, testutils.DB.Save(&n1), "preparing a note data")
testutils.MustExec(t, db.Save(&n1), "preparing a note data")
n2 := database.Note{
UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b2.UUID,
Body: n2Body,
USN: 5,
Deleted: tc.deleted,
}
testutils.MustExec(t, testutils.DB.Save(&n2), "preparing a note data")
testutils.MustExec(t, db.Save(&n2), "preparing a note data")
n3 := database.Note{
UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b2.UUID,
Body: n3Body,
USN: 6,
Deleted: tc.deleted,
}
testutils.MustExec(t, testutils.DB.Save(&n3), "preparing a note data")
testutils.MustExec(t, db.Save(&n3), "preparing a note data")
n4 := database.Note{
UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b2.UUID,
Body: "",
USN: 7,
Deleted: true,
}
testutils.MustExec(t, testutils.DB.Save(&n4), "preparing a note data")
testutils.MustExec(t, db.Save(&n4), "preparing a note data")
n5 := database.Note{
UUID: testutils.MustUUID(t),
UserID: anotherUser.ID,
BookUUID: b3.UUID,
Body: "n5 content",
USN: 8,
}
testutils.MustExec(t, testutils.DB.Save(&n5), "preparing a note data")
testutils.MustExec(t, db.Save(&n5), "preparing a note data")
// Execute
endpoint := fmt.Sprintf("/api/v3/books/%s", b2.UUID)
req := testutils.MakeReq(server.URL, "DELETE", endpoint, "")
res := testutils.HTTPAuthDo(t, req, user)
res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "")
@ -605,17 +647,17 @@ func TestDeleteBook(t *testing.T) {
var userRecord database.User
var bookCount, noteCount int64
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
testutils.MustExec(t, testutils.DB.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
testutils.MustExec(t, testutils.DB.Where("id = ?", b3.ID).First(&b3Record), "finding b3")
testutils.MustExec(t, testutils.DB.Where("id = ?", n1.ID).First(&n1Record), "finding n1")
testutils.MustExec(t, testutils.DB.Where("id = ?", n2.ID).First(&n2Record), "finding n2")
testutils.MustExec(t, testutils.DB.Where("id = ?", n3.ID).First(&n3Record), "finding n3")
testutils.MustExec(t, testutils.DB.Where("id = ?", n4.ID).First(&n4Record), "finding n4")
testutils.MustExec(t, testutils.DB.Where("id = ?", n5.ID).First(&n5Record), "finding n5")
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1")
testutils.MustExec(t, db.Where("id = ?", b2.ID).First(&b2Record), "finding b2")
testutils.MustExec(t, db.Where("id = ?", b3.ID).First(&b3Record), "finding b3")
testutils.MustExec(t, db.Where("id = ?", n1.ID).First(&n1Record), "finding n1")
testutils.MustExec(t, db.Where("id = ?", n2.ID).First(&n2Record), "finding n2")
testutils.MustExec(t, db.Where("id = ?", n3.ID).First(&n3Record), "finding n3")
testutils.MustExec(t, db.Where("id = ?", n4.ID).First(&n4Record), "finding n4")
testutils.MustExec(t, db.Where("id = ?", n5.ID).First(&n5Record), "finding n5")
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equal(t, bookCount, int64(3), "book count mismatch")
assert.Equal(t, noteCount, int64(5), "note count mismatch")

View file

@ -24,16 +24,15 @@ import (
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/server/app"
"github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/testutils"
)
func TestHealth(t *testing.T) {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
server := MustNewServer(t, &app.App{
Config: config.Config{},
})
a := app.NewTest()
a.DB = db
server := MustNewServer(t, &a)
defer server.Close()
// Execute

View file

@ -61,13 +61,6 @@ func parseForm(r *http.Request, dst interface{}) error {
return parseValues(r.PostForm, dst)
}
func parseURLParams(r *http.Request, dst interface{}) error {
if err := r.ParseForm(); err != nil {
return err
}
return parseValues(r.Form, dst)
}
func parseValues(values url.Values, dst interface{}) error {
dec := schema.NewDecoder()

View file

@ -21,15 +21,13 @@ package controllers
import (
"os"
"testing"
"github.com/dnote/dnote/pkg/server/testutils"
"time"
)
func TestMain(m *testing.M) {
testutils.InitTestDB()
// Set timezone to UTC to match database timestamps
time.Local = time.UTC
code := m.Run()
testutils.ClearData(testutils.DB)
os.Exit(code)
}

View file

@ -19,13 +19,10 @@
package controllers
import (
"math"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
"github.com/dnote/dnote/pkg/server/app"
"github.com/dnote/dnote/pkg/server/context"
@ -150,69 +147,6 @@ func (n *Notes) getNotes(r *http.Request) (app.GetNotesResult, app.GetNotesParam
return res, p, nil
}
type noteGroup struct {
Year int
Month int
Data []database.Note
}
type bucketKey struct {
year int
month time.Month
}
func groupNotes(notes []database.Note) []noteGroup {
ret := []noteGroup{}
buckets := map[bucketKey][]database.Note{}
for _, note := range notes {
year := note.UpdatedAt.Year()
month := note.UpdatedAt.Month()
key := bucketKey{year, month}
if _, ok := buckets[key]; !ok {
buckets[key] = []database.Note{}
}
buckets[key] = append(buckets[key], note)
}
keys := []bucketKey{}
for key := range buckets {
keys = append(keys, key)
}
sort.Slice(keys, func(i, j int) bool {
yearI := keys[i].year
yearJ := keys[j].year
monthI := keys[i].month
monthJ := keys[j].month
if yearI == yearJ {
return monthI < monthJ
}
return yearI < yearJ
})
for _, key := range keys {
group := noteGroup{
Year: key.year,
Month: int(key.month),
Data: buckets[key],
}
ret = append(ret, group)
}
return ret
}
func getMaxPage(page, total int) int {
tmp := float64(total) / float64(notesPerPage)
return int(math.Ceil(tmp))
}
// GetNotesResponse is a reponse by getNotesHandler
type GetNotesResponse struct {
Notes []presenters.Note `json:"notes"`

View file

@ -21,7 +21,7 @@ package controllers
import (
"encoding/json"
"fmt"
"io/ioutil"
"io"
"net/http"
"testing"
"time"
@ -29,7 +29,6 @@ import (
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/app"
"github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/presenters"
"github.com/dnote/dnote/pkg/server/testutils"
@ -39,8 +38,8 @@ import (
func getExpectedNotePayload(n database.Note, b database.Book, u database.User) presenters.Note {
return presenters.Note{
UUID: n.UUID,
CreatedAt: n.CreatedAt,
UpdatedAt: n.UpdatedAt,
CreatedAt: truncateMicro(n.CreatedAt),
UpdatedAt: truncateMicro(n.UpdatedAt),
Body: n.Body,
AddedOn: n.AddedOn,
Public: n.Public,
@ -56,37 +55,41 @@ func getExpectedNotePayload(n database.Note, b database.Book, u database.User) p
}
func TestGetNotes(t *testing.T) {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
// Setup
server := MustNewServer(t, &app.App{
Clock: clock.NewMock(),
Config: config.Config{},
})
a := app.NewTest()
a.DB = db
a.Clock = clock.NewMock()
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData()
testutils.SetupAccountData(user, "alice@test.com", "pass1234")
anotherUser := testutils.SetupUserData()
testutils.SetupAccountData(anotherUser, "bob@test.com", "pass1234")
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
anotherUser := testutils.SetupUserData(db)
testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234")
b1 := database.Book{
UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
testutils.MustExec(t, db.Save(&b1), "preparing b1")
b2 := database.Book{
UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "css",
}
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
testutils.MustExec(t, db.Save(&b2), "preparing b2")
b3 := database.Book{
UUID: testutils.MustUUID(t),
UserID: anotherUser.ID,
Label: "css",
}
testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3")
testutils.MustExec(t, db.Save(&b3), "preparing b3")
n1 := database.Note{
UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: "n1 content",
@ -94,8 +97,9 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
testutils.MustExec(t, testutils.DB.Save(&n1), "preparing n1")
testutils.MustExec(t, db.Save(&n1), "preparing n1")
n2 := database.Note{
UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: "n2 content",
@ -103,8 +107,9 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2018, time.August, 11, 22, 0, 0, 0, time.UTC).UnixNano(),
}
testutils.MustExec(t, testutils.DB.Save(&n2), "preparing n2")
testutils.MustExec(t, db.Save(&n2), "preparing n2")
n3 := database.Note{
UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: "n3 content",
@ -112,8 +117,9 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2017, time.January, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
testutils.MustExec(t, testutils.DB.Save(&n3), "preparing n3")
testutils.MustExec(t, db.Save(&n3), "preparing n3")
n4 := database.Note{
UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b2.UUID,
Body: "n4 content",
@ -121,8 +127,9 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2018, time.September, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
testutils.MustExec(t, testutils.DB.Save(&n4), "preparing n4")
testutils.MustExec(t, db.Save(&n4), "preparing n4")
n5 := database.Note{
UUID: testutils.MustUUID(t),
UserID: anotherUser.ID,
BookUUID: b3.UUID,
Body: "n5 content",
@ -130,8 +137,9 @@ func TestGetNotes(t *testing.T) {
Deleted: false,
AddedOn: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
testutils.MustExec(t, testutils.DB.Save(&n5), "preparing n5")
testutils.MustExec(t, db.Save(&n5), "preparing n5")
n6 := database.Note{
UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: "",
@ -139,13 +147,13 @@ func TestGetNotes(t *testing.T) {
Deleted: true,
AddedOn: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC).UnixNano(),
}
testutils.MustExec(t, testutils.DB.Save(&n6), "preparing n6")
testutils.MustExec(t, db.Save(&n6), "preparing n6")
// Execute
endpoint := "/api/v3/notes"
req := testutils.MakeReq(server.URL, "GET", fmt.Sprintf("%s?year=2018&month=8", endpoint), "")
res := testutils.HTTPAuthDo(t, req, user)
res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "")
@ -156,8 +164,8 @@ func TestGetNotes(t *testing.T) {
}
var n2Record, n1Record database.Note
testutils.MustExec(t, testutils.DB.Where("uuid = ?", n2.UUID).First(&n2Record), "finding n2Record")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", n1.UUID).First(&n1Record), "finding n1Record")
testutils.MustExec(t, db.Where("uuid = ?", n2.UUID).First(&n2Record), "finding n2Record")
testutils.MustExec(t, db.Where("uuid = ?", n1.UUID).First(&n1Record), "finding n1Record")
expected := GetNotesResponse{
Notes: []presenters.Note{
@ -171,44 +179,48 @@ func TestGetNotes(t *testing.T) {
}
func TestGetNote(t *testing.T) {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
// Setup
server := MustNewServer(t, &app.App{
Clock: clock.NewMock(),
Config: config.Config{},
})
a := app.NewTest()
a.DB = db
a.Clock = clock.NewMock()
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData()
anotherUser := testutils.SetupUserData()
user := testutils.SetupUserData(db)
anotherUser := testutils.SetupUserData(db)
b1 := database.Book{
UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
testutils.MustExec(t, db.Save(&b1), "preparing b1")
privateNote := database.Note{
UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: "privateNote content",
Public: false,
}
testutils.MustExec(t, testutils.DB.Save(&privateNote), "preparing privateNote")
testutils.MustExec(t, db.Save(&privateNote), "preparing privateNote")
publicNote := database.Note{
UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: "publicNote content",
Public: true,
}
testutils.MustExec(t, testutils.DB.Save(&publicNote), "preparing publicNote")
testutils.MustExec(t, db.Save(&publicNote), "preparing publicNote")
deletedNote := database.Note{
UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Deleted: true,
}
testutils.MustExec(t, testutils.DB.Save(&deletedNote), "preparing deletedNote")
testutils.MustExec(t, db.Save(&deletedNote), "preparing deletedNote")
getURL := func(noteUUID string) string {
return fmt.Sprintf("/api/v3/notes/%s", noteUUID)
@ -218,7 +230,7 @@ func TestGetNote(t *testing.T) {
// Execute
url := getURL(publicNote.UUID)
req := testutils.MakeReq(server.URL, "GET", url, "")
res := testutils.HTTPAuthDo(t, req, user)
res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "")
@ -229,7 +241,7 @@ func TestGetNote(t *testing.T) {
}
var n2Record database.Note
testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
expected := getExpectedNotePayload(n2Record, b1, user)
assert.DeepEqual(t, payload, expected, "payload mismatch")
@ -239,7 +251,7 @@ func TestGetNote(t *testing.T) {
// Execute
url := getURL(publicNote.UUID)
req := testutils.MakeReq(server.URL, "GET", url, "")
res := testutils.HTTPAuthDo(t, req, user)
res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "")
@ -250,7 +262,7 @@ func TestGetNote(t *testing.T) {
}
var n2Record database.Note
testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
expected := getExpectedNotePayload(n2Record, b1, user)
assert.DeepEqual(t, payload, expected, "payload mismatch")
@ -260,7 +272,7 @@ func TestGetNote(t *testing.T) {
// Execute
url := getURL(publicNote.UUID)
req := testutils.MakeReq(server.URL, "GET", url, "")
res := testutils.HTTPAuthDo(t, req, anotherUser)
res := testutils.HTTPAuthDo(t, db, req, anotherUser)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "")
@ -271,7 +283,7 @@ func TestGetNote(t *testing.T) {
}
var n2Record database.Note
testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
expected := getExpectedNotePayload(n2Record, b1, user)
assert.DeepEqual(t, payload, expected, "payload mismatch")
@ -281,12 +293,12 @@ func TestGetNote(t *testing.T) {
// Execute
url := getURL(privateNote.UUID)
req := testutils.MakeReq(server.URL, "GET", url, "")
res := testutils.HTTPAuthDo(t, req, anotherUser)
res := testutils.HTTPAuthDo(t, db, req, anotherUser)
// Test
assert.StatusCodeEquals(t, res, http.StatusNotFound, "")
body, err := ioutil.ReadAll(res.Body)
body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(errors.Wrap(err, "reading body"))
}
@ -309,7 +321,7 @@ func TestGetNote(t *testing.T) {
}
var n2Record database.Note
testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record")
expected := getExpectedNotePayload(n2Record, b1, user)
assert.DeepEqual(t, payload, expected, "payload mismatch")
@ -324,7 +336,7 @@ func TestGetNote(t *testing.T) {
// Test
assert.StatusCodeEquals(t, res, http.StatusNotFound, "")
body, err := ioutil.ReadAll(res.Body)
body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(errors.Wrap(err, "reading body"))
}
@ -336,12 +348,12 @@ func TestGetNote(t *testing.T) {
// Execute
url := getURL("somerandomstring")
req := testutils.MakeReq(server.URL, "GET", url, "")
res := testutils.HTTPAuthDo(t, req, user)
res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusNotFound, "")
body, err := ioutil.ReadAll(res.Body)
body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(errors.Wrap(err, "reading body"))
}
@ -353,12 +365,12 @@ func TestGetNote(t *testing.T) {
// Execute
url := getURL(deletedNote.UUID)
req := testutils.MakeReq(server.URL, "GET", url, "")
res := testutils.HTTPAuthDo(t, req, user)
res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusNotFound, "")
body, err := ioutil.ReadAll(res.Body)
body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(errors.Wrap(err, "reading body"))
}
@ -368,31 +380,32 @@ func TestGetNote(t *testing.T) {
}
func TestCreateNote(t *testing.T) {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
// Setup
server := MustNewServer(t, &app.App{
Clock: clock.NewMock(),
Config: config.Config{},
})
a := app.NewTest()
a.DB = db
a.Clock = clock.NewMock()
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData()
testutils.SetupAccountData(user, "alice@test.com", "pass1234")
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
UUID: testutils.MustUUID(t),
UserID: user.ID,
Label: "js",
USN: 58,
}
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
testutils.MustExec(t, db.Save(&b1), "preparing b1")
// Execute
dat := fmt.Sprintf(`{"book_uuid": "%s", "content": "note content"}`, b1.UUID)
req := testutils.MakeReq(server.URL, "POST", "/api/v3/notes", dat)
res := testutils.HTTPAuthDo(t, req, user)
res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusCreated, "")
@ -401,11 +414,11 @@ func TestCreateNote(t *testing.T) {
var bookRecord database.Book
var userRecord database.User
var bookCount, noteCount int64
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, testutils.DB.First(&noteRecord), "finding note")
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, db.First(&noteRecord), "finding note")
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, int64(1), "book count mismatch")
assert.Equalf(t, noteCount, int64(1), "note count mismatch")
@ -449,38 +462,39 @@ func TestDeleteNote(t *testing.T) {
for idx, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
// Setup
server := MustNewServer(t, &app.App{
Clock: clock.NewMock(),
Config: config.Config{},
})
a := app.NewTest()
a.DB = db
a.Clock = clock.NewMock()
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData()
testutils.SetupAccountData(user, "alice@test.com", "pass1234")
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 981), "preparing user max_usn")
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
testutils.MustExec(t, db.Model(&user).Update("max_usn", 981), "preparing user max_usn")
b1 := database.Book{
UUID: b1UUID,
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
testutils.MustExec(t, db.Save(&b1), "preparing b1")
note := database.Note{
UUID: testutils.MustUUID(t),
UserID: user.ID,
BookUUID: b1.UUID,
Body: tc.content,
Deleted: tc.deleted,
USN: tc.originalUSN,
}
testutils.MustExec(t, testutils.DB.Save(&note), "preparing note")
testutils.MustExec(t, db.Save(&note), "preparing note")
// Execute
endpoint := fmt.Sprintf("/api/v3/notes/%s", note.UUID)
req := testutils.MakeReq(server.URL, "DELETE", endpoint, "")
res := testutils.HTTPAuthDo(t, req, user)
res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "")
@ -489,11 +503,11 @@ func TestDeleteNote(t *testing.T) {
var noteRecord database.Note
var userRecord database.User
var bookCount, noteCount int64
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", note.UUID).First(&noteRecord), "finding note")
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, db.Where("uuid = ?", note.UUID).First(&noteRecord), "finding note")
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, int64(1), "book count mismatch")
assert.Equalf(t, noteCount, int64(1), "note count mismatch")
@ -687,42 +701,42 @@ func TestUpdateNote(t *testing.T) {
for idx, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
defer testutils.ClearData(testutils.DB)
db := testutils.InitMemoryDB(t)
// Setup
server := MustNewServer(t, &app.App{
Clock: clock.NewMock(),
Config: config.Config{},
})
a := app.NewTest()
a.DB = db
a.Clock = clock.NewMock()
server := MustNewServer(t, &a)
defer server.Close()
user := testutils.SetupUserData()
testutils.SetupAccountData(user, "alice@test.com", "pass1234")
user := testutils.SetupUserData(db)
testutils.SetupAccountData(db, user, "alice@test.com", "pass1234")
testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn")
testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn")
b1 := database.Book{
UUID: b1UUID,
UserID: user.ID,
Label: "css",
}
testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1")
testutils.MustExec(t, db.Save(&b1), "preparing b1")
b2 := database.Book{
UUID: b2UUID,
UserID: user.ID,
Label: "js",
}
testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2")
testutils.MustExec(t, db.Save(&b2), "preparing b2")
note := database.Note{
UserID: user.ID,
UUID: tc.noteUUID,
UserID: user.ID,
BookUUID: tc.noteBookUUID,
Body: tc.noteBody,
Deleted: tc.noteDeleted,
Public: tc.notePublic,
}
testutils.MustExec(t, testutils.DB.Save(&note), "preparing note")
testutils.MustExec(t, db.Save(&note), "preparing note")
// Execute
var req *http.Request
@ -730,7 +744,7 @@ func TestUpdateNote(t *testing.T) {
endpoint := fmt.Sprintf("/api/v3/notes/%s", note.UUID)
req = testutils.MakeReq(server.URL, "PATCH", endpoint, tc.payload.ToJSON(t))
res := testutils.HTTPAuthDo(t, req, user)
res := testutils.HTTPAuthDo(t, db, req, user)
// Test
assert.StatusCodeEquals(t, res, http.StatusOK, "status code mismatch for test case")
@ -739,11 +753,11 @@ func TestUpdateNote(t *testing.T) {
var noteRecord database.Note
var userRecord database.User
var noteCount, bookCount int64
testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, testutils.DB.Where("uuid = ?", note.UUID).First(&noteRecord), "finding note")
testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record")
testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books")
testutils.MustExec(t, db.Model(&database.Note{}).Count(&noteCount), "counting notes")
testutils.MustExec(t, db.Where("uuid = ?", note.UUID).First(&noteRecord), "finding note")
testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book")
testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record")
assert.Equalf(t, bookCount, int64(2), "book count mismatch")
assert.Equalf(t, noteCount, int64(1), "note count mismatch")

View file

@ -48,25 +48,25 @@ func NewWebRoutes(a *app.App, c *Controllers) []Route {
redirectGuest := &mw.AuthParams{RedirectGuestsToLogin: true}
ret := []Route{
{"GET", "/", mw.Auth(a, c.Users.Settings, redirectGuest), true},
{"GET", "/about", mw.Auth(a, c.Users.About, redirectGuest), true},
{"GET", "/login", mw.GuestOnly(a, c.Users.NewLogin), true},
{"POST", "/login", mw.GuestOnly(a, c.Users.Login), true},
{"GET", "/", mw.Auth(a.DB, c.Users.Settings, redirectGuest), true},
{"GET", "/about", mw.Auth(a.DB, c.Users.About, redirectGuest), true},
{"GET", "/login", mw.GuestOnly(a.DB, c.Users.NewLogin), true},
{"POST", "/login", mw.GuestOnly(a.DB, c.Users.Login), true},
{"POST", "/logout", c.Users.Logout, true},
{"GET", "/password-reset", c.Users.PasswordResetView.ServeHTTP, true},
{"PATCH", "/password-reset", c.Users.PasswordReset, true},
{"GET", "/password-reset/{token}", c.Users.PasswordResetConfirm, true},
{"POST", "/reset-token", c.Users.CreateResetToken, true},
{"POST", "/verification-token", mw.Auth(a, c.Users.CreateEmailVerificationToken, redirectGuest), true},
{"GET", "/verify-email/{token}", mw.Auth(a, c.Users.VerifyEmail, redirectGuest), true},
{"PATCH", "/account/profile", mw.Auth(a, c.Users.ProfileUpdate, nil), true},
{"PATCH", "/account/password", mw.Auth(a, c.Users.PasswordUpdate, nil), true},
{"POST", "/verification-token", mw.Auth(a.DB, c.Users.CreateEmailVerificationToken, redirectGuest), true},
{"GET", "/verify-email/{token}", mw.Auth(a.DB, c.Users.VerifyEmail, redirectGuest), true},
{"PATCH", "/account/profile", mw.Auth(a.DB, c.Users.ProfileUpdate, nil), true},
{"PATCH", "/account/password", mw.Auth(a.DB, c.Users.PasswordUpdate, nil), true},
{"GET", "/health", c.Health.Index, true},
}
if !a.Config.DisableRegistration {
if !a.DisableRegistration {
ret = append(ret, Route{"GET", "/join", c.Users.New, true})
ret = append(ret, Route{"POST", "/join", c.Users.Create, true})
}
@ -76,28 +76,25 @@ func NewWebRoutes(a *app.App, c *Controllers) []Route {
// NewAPIRoutes returns a new api routes
func NewAPIRoutes(a *app.App, c *Controllers) []Route {
proOnly := mw.AuthParams{ProOnly: true}
return []Route{
// v3
{"GET", "/v3/sync/fragment", mw.Cors(mw.Auth(a, c.Sync.GetSyncFragment, &proOnly)), false},
{"GET", "/v3/sync/state", mw.Cors(mw.Auth(a, c.Sync.GetSyncState, &proOnly)), false},
{"POST", "/v3/signin", mw.Cors(c.Users.V3Login), true},
{"POST", "/v3/signout", mw.Cors(c.Users.V3Logout), true},
{"OPTIONS", "/v3/signout", mw.Cors(c.Users.logoutOptions), true},
{"GET", "/v3/notes", mw.Cors(mw.Auth(a, c.Notes.V3Index, nil)), true},
{"GET", "/v3/sync/fragment", mw.Auth(a.DB, c.Sync.GetSyncFragment, nil), false},
{"GET", "/v3/sync/state", mw.Auth(a.DB, c.Sync.GetSyncState, nil), false},
{"POST", "/v3/signin", c.Users.V3Login, true},
{"POST", "/v3/signout", c.Users.V3Logout, true},
{"OPTIONS", "/v3/signout", c.Users.logoutOptions, true},
{"GET", "/v3/notes", mw.Auth(a.DB, c.Notes.V3Index, nil), true},
{"GET", "/v3/notes/{noteUUID}", c.Notes.V3Show, true},
{"POST", "/v3/notes", mw.Cors(mw.Auth(a, c.Notes.V3Create, nil)), true},
{"DELETE", "/v3/notes/{noteUUID}", mw.Cors(mw.Auth(a, c.Notes.V3Delete, nil)), true},
{"PATCH", "/v3/notes/{noteUUID}", mw.Cors(mw.Auth(a, c.Notes.V3Update, nil)), true},
{"OPTIONS", "/v3/notes", mw.Cors(c.Notes.IndexOptions), true},
{"GET", "/v3/books", mw.Cors(mw.Auth(a, c.Books.V3Index, nil)), true},
{"GET", "/v3/books/{bookUUID}", mw.Cors(mw.Auth(a, c.Books.V3Show, nil)), true},
{"POST", "/v3/books", mw.Cors(mw.Auth(a, c.Books.V3Create, nil)), true},
{"PATCH", "/v3/books/{bookUUID}", mw.Cors(mw.Auth(a, c.Books.V3Update, nil)), true},
{"DELETE", "/v3/books/{bookUUID}", mw.Cors(mw.Auth(a, c.Books.V3Delete, nil)), true},
{"OPTIONS", "/v3/books", mw.Cors(c.Books.IndexOptions), true},
{"POST", "/v3/notes", mw.Auth(a.DB, c.Notes.V3Create, nil), true},
{"DELETE", "/v3/notes/{noteUUID}", mw.Auth(a.DB, c.Notes.V3Delete, nil), true},
{"PATCH", "/v3/notes/{noteUUID}", mw.Auth(a.DB, c.Notes.V3Update, nil), true},
{"OPTIONS", "/v3/notes", c.Notes.IndexOptions, true},
{"GET", "/v3/books", mw.Auth(a.DB, c.Books.V3Index, nil), true},
{"GET", "/v3/books/{bookUUID}", mw.Auth(a.DB, c.Books.V3Show, nil), true},
{"POST", "/v3/books", mw.Auth(a.DB, c.Books.V3Create, nil), true},
{"PATCH", "/v3/books/{bookUUID}", mw.Auth(a.DB, c.Books.V3Update, nil), true},
{"DELETE", "/v3/books/{bookUUID}", mw.Auth(a.DB, c.Books.V3Delete, nil), true},
{"OPTIONS", "/v3/books", c.Books.IndexOptions, true},
}
}

View file

@ -25,7 +25,6 @@ import (
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/app"
"github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/testutils"
)
@ -56,10 +55,11 @@ func TestNotSupportedVersions(t *testing.T) {
}
// setup
server := MustNewServer(t, &app.App{
Clock: clock.NewMock(),
Config: config.Config{},
})
db := testutils.InitMemoryDB(t)
a := app.NewTest()
a.Clock = clock.NewMock()
a.DB = db
server := MustNewServer(t, &a)
defer server.Close()
for _, tc := range testCases {

View file

@ -27,9 +27,9 @@ import (
)
// MustNewServer is a test utility function to initialize a new server
// with the given app paratmers
func MustNewServer(t *testing.T, appParams *app.App) *httptest.Server {
server, err := NewServer(appParams)
// with the given app
func MustNewServer(t *testing.T, a *app.App) *httptest.Server {
server, err := NewServer(a)
if err != nil {
t.Fatal(errors.Wrap(err, "initializing router"))
}
@ -37,16 +37,14 @@ func MustNewServer(t *testing.T, appParams *app.App) *httptest.Server {
return server
}
func NewServer(appParams *app.App) (*httptest.Server, error) {
a := app.NewTest(appParams)
ctl := New(&a)
func NewServer(a *app.App) (*httptest.Server, error) {
ctl := New(a)
rc := RouteConfig{
WebRoutes: NewWebRoutes(&a, ctl),
APIRoutes: NewAPIRoutes(&a, ctl),
WebRoutes: NewWebRoutes(a, ctl),
APIRoutes: NewAPIRoutes(a, ctl),
Controllers: ctl,
}
r, err := NewRouter(&a, rc)
r, err := NewRouter(a, rc)
if err != nil {
return nil, errors.Wrap(err, "initializing router")
}

File diff suppressed because it is too large Load diff

View file

@ -23,8 +23,6 @@ const (
TokenTypeResetPassword = "reset_password"
// TokenTypeEmailVerification is a type of a token for verifying email
TokenTypeEmailVerification = "email_verification"
// TokenTypeEmailPreference is a type of a token for updating email preference
TokenTypeEmailPreference = "email_preference"
)
const (

View file

@ -19,9 +19,11 @@
package database
import (
"github.com/dnote/dnote/pkg/server/config"
"os"
"path/filepath"
"github.com/pkg/errors"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
@ -32,18 +34,12 @@ var (
// InitSchema migrates database schema to reflect the latest model definition
func InitSchema(db *gorm.DB) {
if err := db.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`).Error; err != nil {
panic(err)
}
if err := db.AutoMigrate(
&User{},
&Account{},
&Book{},
&Note{},
&Notification{},
&Token{},
&EmailPreference{},
&Session{},
); err != nil {
panic(err)
@ -51,8 +47,14 @@ func InitSchema(db *gorm.DB) {
}
// Open initializes the database connection
func Open(c config.Config) *gorm.DB {
db, err := gorm.Open(postgres.Open(c.DB.GetConnectionStr()), &gorm.Config{})
func Open(dbPath string) *gorm.DB {
// Create directory if it doesn't exist
dir := filepath.Dir(dbPath)
if err := os.MkdirAll(dir, 0755); err != nil {
panic(errors.Wrapf(err, "creating database directory at %s", dir))
}
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
if err != nil {
panic(errors.Wrap(err, "opening database conection"))
}

View file

@ -19,34 +19,167 @@
package database
import (
"log"
"net/http"
"fmt"
"io/fs"
"sort"
"strings"
"github.com/dnote/dnote/pkg/server/database/migrations"
"gorm.io/gorm"
"github.com/dnote/dnote/pkg/server/log"
"github.com/pkg/errors"
"github.com/rubenv/sql-migrate"
"gorm.io/gorm"
)
// Migrate runs the migrations
func Migrate(db *gorm.DB) error {
migrations := &migrate.HttpFileSystemMigrationSource{
FileSystem: http.FileSystem(http.FS(migrations.Files)),
type migrationFile struct {
filename string
version int
}
// validateMigrationFilename checks if filename follows format: NNN-description.sql
func validateMigrationFilename(name string) error {
// Check .sql extension
if !strings.HasSuffix(name, ".sql") {
return errors.Errorf("invalid migration filename: must end with .sql")
}
migrate.SetTable(MigrationTableName)
sqlDB, err := db.DB()
if err != nil {
return errors.Wrap(err, "getting underlying sql.DB")
name = strings.TrimSuffix(name, ".sql")
parts := strings.SplitN(name, "-", 2)
if len(parts) != 2 {
return errors.Errorf("invalid migration filename: must be NNN-description.sql")
}
n, err := migrate.Exec(sqlDB, "postgres", migrations, migrate.Up)
if err != nil {
return errors.Wrap(err, "running migrations")
version, description := parts[0], parts[1]
// Validate version is 3 digits
if len(version) != 3 {
return errors.Errorf("invalid migration filename: version must be 3 digits, got %s", version)
}
for _, c := range version {
if c < '0' || c > '9' {
return errors.Errorf("invalid migration filename: version must be numeric, got %s", version)
}
}
log.Printf("Performed %d migrations", n)
// Validate description is not empty
if description == "" {
return errors.Errorf("invalid migration filename: description is required")
}
return nil
}
// Migrate runs the migrations using the embedded migration files
func Migrate(db *gorm.DB) error {
return migrate(db, migrations.Files)
}
// getMigrationFiles reads, validates, and sorts migration files
func getMigrationFiles(fsys fs.FS) ([]migrationFile, error) {
entries, err := fs.ReadDir(fsys, ".")
if err != nil {
return nil, errors.Wrap(err, "reading migration directory")
}
var migrations []migrationFile
seen := make(map[int]string)
for _, e := range entries {
name := e.Name()
if err := validateMigrationFilename(name); err != nil {
return nil, err
}
// Parse version
var v int
fmt.Sscanf(name, "%d", &v)
// Check for duplicate version numbers
if existing, found := seen[v]; found {
return nil, errors.Errorf("duplicate migration version %d: %s and %s", v, existing, name)
}
seen[v] = name
migrations = append(migrations, migrationFile{
filename: name,
version: v,
})
}
// Sort by version
sort.Slice(migrations, func(i, j int) bool {
return migrations[i].version < migrations[j].version
})
return migrations, nil
}
// migrate runs migrations from the provided filesystem
func migrate(db *gorm.DB, fsys fs.FS) error {
if err := db.Exec(`
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`).Error; err != nil {
return errors.Wrap(err, "initializing migration table")
}
// Get current version
var version int
if err := db.Raw("SELECT COALESCE(MAX(version), 0) FROM schema_migrations").Scan(&version).Error; err != nil {
return errors.Wrap(err, "reading current version")
}
// Read and validate migration files
migrations, err := getMigrationFiles(fsys)
if err != nil {
return err
}
var filenames []string
for _, m := range migrations {
filenames = append(filenames, m.filename)
}
log.WithFields(log.Fields{
"version": version,
}).Info("Database schema version.")
log.WithFields(log.Fields{
"files": filenames,
}).Debug("Database migration files.")
// Apply pending migrations
for _, m := range migrations {
if m.version <= version {
continue
}
log.WithFields(log.Fields{
"file": m.filename,
}).Info("Applying migration.")
sql, err := fs.ReadFile(fsys, m.filename)
if err != nil {
return errors.Wrapf(err, "reading migration file %s", m.filename)
}
if len(strings.TrimSpace(string(sql))) == 0 {
return errors.Errorf("migration file %s is empty", m.filename)
}
if err := db.Exec(string(sql)).Error; err != nil {
return fmt.Errorf("migration %s failed: %w", m.filename, err)
}
if err := db.Exec("INSERT INTO schema_migrations (version) VALUES (?)", m.version).Error; err != nil {
return errors.Wrapf(err, "recording migration %s", m.filename)
}
log.WithFields(log.Fields{
"file": m.filename,
}).Info("Migrate success.")
}
return nil
}

View file

@ -1,72 +0,0 @@
/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
*
* This file is part of Dnote.
*
* Dnote is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Dnote is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with Dnote. If not, see <https://www.gnu.org/licenses/>.
*/
package main
import (
"flag"
"fmt"
"os"
"github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/database"
"github.com/joho/godotenv"
"github.com/pkg/errors"
"github.com/rubenv/sql-migrate"
)
var (
migrationDir = flag.String("migrationDir", "../migrations", "the path to the directory with migraiton files")
)
func init() {
fmt.Println("Migrating Dnote database...")
// Load env
if os.Getenv("GO_ENV") != "PRODUCTION" {
if err := godotenv.Load("../../.env.dev"); err != nil {
panic(err)
}
}
}
func main() {
flag.Parse()
c := config.Load()
db := database.Open(c)
migrations := &migrate.FileMigrationSource{
Dir: *migrationDir,
}
migrate.SetTable("migrations")
sqlDB, err := db.DB()
if err != nil {
panic(errors.Wrap(err, "getting underlying sql.DB"))
}
n, err := migrate.Exec(sqlDB, "postgres", migrations, migrate.Up)
if err != nil {
panic(errors.Wrap(err, "executing migrations"))
}
fmt.Printf("Applied %d migrations\n", n)
}

View file

@ -0,0 +1,313 @@
/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
*
* This file is part of Dnote.
*
* Dnote is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Dnote is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with Dnote. If not, see <https://www.gnu.org/licenses/>.
*/
package database
import (
"io/fs"
"testing"
"testing/fstest"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// unsortedFS wraps fstest.MapFS to return entries in reverse order
type unsortedFS struct {
fstest.MapFS
}
func (u unsortedFS) ReadDir(name string) ([]fs.DirEntry, error) {
entries, err := u.MapFS.ReadDir(name)
if err != nil {
return nil, err
}
// Reverse the entries to ensure they're not in sorted order
for i, j := 0, len(entries)-1; i < j; i, j = i+1, j-1 {
entries[i], entries[j] = entries[j], entries[i]
}
return entries, nil
}
// errorFS returns an error on ReadDir
type errorFS struct{}
func (e errorFS) Open(name string) (fs.File, error) {
return nil, fs.ErrNotExist
}
func (e errorFS) ReadDir(name string) ([]fs.DirEntry, error) {
return nil, fs.ErrPermission
}
func TestMigrate_createsSchemaTable(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("failed to open database: %v", err)
}
migrationsFs := fstest.MapFS{}
migrate(db, migrationsFs)
// Verify schema_migrations table exists
var count int64
if err := db.Raw("SELECT COUNT(*) FROM schema_migrations").Scan(&count).Error; err != nil {
t.Fatalf("schema_migrations table not found: %v", err)
}
}
func TestMigrate_idempotency(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("failed to open database: %v", err)
}
// Set up table before migration
if err := db.Exec("CREATE TABLE counter (value INTEGER)").Error; err != nil {
t.Fatalf("failed to create table: %v", err)
}
// Create migration that inserts a row
migrationsFs := fstest.MapFS{
"001-insert-data.sql": &fstest.MapFile{
Data: []byte("INSERT INTO counter (value) VALUES (100);"),
},
}
// Run migration first time
if err := migrate(db, migrationsFs); err != nil {
t.Fatalf("first migration failed: %v", err)
}
var count int64
if err := db.Raw("SELECT COUNT(*) FROM counter").Scan(&count).Error; err != nil {
t.Fatalf("failed to count rows: %v", err)
}
if count != 1 {
t.Errorf("expected 1 row, got %d", count)
}
// Run migration second time - it should not run the SQL again
if err := migrate(db, migrationsFs); err != nil {
t.Fatalf("second migration failed: %v", err)
}
if err := db.Raw("SELECT COUNT(*) FROM counter").Scan(&count).Error; err != nil {
t.Fatalf("failed to count rows: %v", err)
}
if count != 1 {
t.Errorf("migration ran twice: expected 1 row, got %d", count)
}
}
func TestMigrate_ordering(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("failed to open database: %v", err)
}
// Create table before migrations
if err := db.Exec("CREATE TABLE log (value INTEGER)").Error; err != nil {
t.Fatalf("failed to create table: %v", err)
}
// Create migrations with unsorted filesystem
migrationsFs := unsortedFS{
MapFS: fstest.MapFS{
"010-tenth.sql": &fstest.MapFile{
Data: []byte("INSERT INTO log (value) VALUES (3);"),
},
"001-first.sql": &fstest.MapFile{
Data: []byte("INSERT INTO log (value) VALUES (1);"),
},
"002-second.sql": &fstest.MapFile{
Data: []byte("INSERT INTO log (value) VALUES (2);"),
},
},
}
// Run migrations
if err := migrate(db, migrationsFs); err != nil {
t.Fatalf("migration failed: %v", err)
}
// Verify migrations ran in correct order (1, 2, 3)
var values []int
if err := db.Raw("SELECT value FROM log ORDER BY rowid").Scan(&values).Error; err != nil {
t.Fatalf("failed to query log: %v", err)
}
expected := []int{1, 2, 3}
if len(values) != len(expected) {
t.Fatalf("expected %d rows, got %d", len(expected), len(values))
}
for i, v := range values {
if v != expected[i] {
t.Errorf("row %d: expected value %d, got %d", i, expected[i], v)
}
}
}
func TestMigrate_duplicateVersion(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("failed to open database: %v", err)
}
// Create migrations with duplicate version numbers
migrationsFs := fstest.MapFS{
"001-first.sql": &fstest.MapFile{
Data: []byte("SELECT 1;"),
},
"001-second.sql": &fstest.MapFile{
Data: []byte("SELECT 2;"),
},
}
// Should return error for duplicate version
err = migrate(db, migrationsFs)
if err == nil {
t.Fatal("expected error for duplicate version numbers, got nil")
}
}
func TestMigrate_initTableError(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("failed to open database: %v", err)
}
// Close the database connection to cause exec to fail
sqlDB, _ := db.DB()
sqlDB.Close()
migrationsFs := fstest.MapFS{
"001-init.sql": &fstest.MapFile{
Data: []byte("SELECT 1;"),
},
}
// Should return error for table initialization failure
err = migrate(db, migrationsFs)
if err == nil {
t.Fatal("expected error for table initialization failure, got nil")
}
}
func TestMigrate_readDirError(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("failed to open database: %v", err)
}
// Use filesystem that fails on ReadDir
err = migrate(db, errorFS{})
if err == nil {
t.Fatal("expected error for ReadDir failure, got nil")
}
}
func TestMigrate_sqlError(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("failed to open database: %v", err)
}
// Create migration with invalid SQL
migrationsFs := fstest.MapFS{
"001-bad-sql.sql": &fstest.MapFile{
Data: []byte("INVALID SQL SYNTAX HERE;"),
},
}
// Should return error for SQL execution failure
err = migrate(db, migrationsFs)
if err == nil {
t.Fatal("expected error for invalid SQL, got nil")
}
}
func TestMigrate_emptyFile(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("failed to open database: %v", err)
}
tests := []struct {
name string
data string
wantErr bool
}{
{"completely empty", "", true},
{"only whitespace", " \n\t ", true},
{"only comments", "-- just a comment", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
migrationsFs := fstest.MapFS{
"001-empty.sql": &fstest.MapFile{
Data: []byte(tt.data),
},
}
err = migrate(db, migrationsFs)
if (err != nil) != tt.wantErr {
t.Errorf("error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestMigrate_invalidFilename(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("failed to open database: %v", err)
}
tests := []struct {
name string
filename string
wantErr bool
}{
{"valid format", "001-init.sql", false},
{"no leading zeros", "1-init.sql", true},
{"two digits", "01-init.sql", true},
{"no dash", "001init.sql", true},
{"no description", "001-.sql", true},
{"no extension", "001-init.", true},
{"wrong extension", "001-init.txt", true},
{"non-numeric version number", "0a1-init.sql", true},
{"underscore separator", "001_init.sql", true},
{"multiple dashes in description", "001-add-feature-v2.sql", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
migrationsFs := fstest.MapFS{
tt.filename: &fstest.MapFile{
Data: []byte("SELECT 1;"),
},
}
err := migrate(db, migrationsFs)
if (err != nil) != tt.wantErr {
t.Errorf("error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -0,0 +1,18 @@
-- Create FTS5 virtual table for full-text search on notes
CREATE VIRTUAL TABLE IF NOT EXISTS notes_fts USING fts5(
content=notes,
body,
tokenize="porter unicode61 categories 'L* N* Co Ps Pe'"
);
-- Create triggers to keep notes_fts in sync with notes
CREATE TRIGGER IF NOT EXISTS notes_insert AFTER INSERT ON notes BEGIN
INSERT INTO notes_fts(rowid, body) VALUES (new.rowid, new.body);
END;
CREATE TRIGGER IF NOT EXISTS notes_delete AFTER DELETE ON notes BEGIN
INSERT INTO notes_fts(notes_fts, rowid, body) VALUES ('delete', old.rowid, old.body);
END;
CREATE TRIGGER IF NOT EXISTS notes_update AFTER UPDATE ON notes BEGIN
INSERT INTO notes_fts(notes_fts, rowid, body) VALUES ('delete', old.rowid, old.body);
INSERT INTO notes_fts(rowid, body) VALUES (new.rowid, new.body);
END;

View file

@ -1,41 +0,0 @@
-- +migrate Up
-- Configure full text search
CREATE TEXT SEARCH DICTIONARY english_nostop (
Template = snowball,
Language = english
);
CREATE TEXT SEARCH CONFIGURATION public.english_nostop ( COPY = pg_catalog.english );
ALTER TEXT SEARCH CONFIGURATION public.english_nostop
ALTER MAPPING FOR asciiword, asciihword, hword_asciipart, hword, hword_part, word WITH english_nostop;
-- Create a trigger
-- +migrate StatementBegin
CREATE OR REPLACE FUNCTION note_tsv_trigger() RETURNS trigger AS $$
begin
new.tsv := setweight(to_tsvector('english_nostop', new.body), 'A');
return new;
end
$$ LANGUAGE plpgsql;
DROP TRIGGER IF EXISTS tsvectorupdate ON notes;
CREATE TRIGGER tsvectorupdate
BEFORE INSERT OR UPDATE ON notes
FOR EACH ROW EXECUTE PROCEDURE note_tsv_trigger();
-- +migrate StatementEnd
-- index tsv
CREATE INDEX IF NOT EXISTS idx_notes_tsv
ON notes
USING gin(tsv);
-- initialize tsv
UPDATE notes
SET tsv = setweight(to_tsvector('english_nostop', notes.body), 'A')
WHERE notes.encrypted = false;
-- +migrate Down

View file

@ -1,8 +0,0 @@
-- this migration is noop because repetition_rules have been removed
-- create-weekly-repetition.sql creates the default repetition rules for the users
-- that used to have the weekly email digest on Friday 20:00 UTC
-- +migrate Up
-- +migrate Down

View file

@ -1,9 +0,0 @@
-- this migration is noop because digests have been removed
-- populate-digest-version.sql populates the `version` column for the digests
-- by assigining an incremental integer scoped to a repetition rule that each
-- digest belongs, ordered by created_at timestamp of the digests.
-- +migrate Up
-- +migrate Down

View file

@ -1,5 +0,0 @@
-- this migration is noop because digests have been removed
-- +migrate Up
-- +migrate Down

View file

@ -1,8 +0,0 @@
-- this migration is noop because digests have been removed
-- -use-id-in-digest-notes-joining-table.sql replaces uuids with ids
-- as foreign keys in the digest_notes joining table.
-- +migrate Up
-- +migrate Down

View file

@ -1,8 +0,0 @@
-- this migration is noop because digests have been removed
-- delete-outdated-digests.sql deletes digests that do not belong to any repetition rules,
-- along with digest_notes associations.
-- +migrate Up
-- +migrate Down

View file

@ -1,9 +0,0 @@
-- remove-billing-columns.sql drops billing related columns that are now obsolete.
-- +migrate Up
ALTER TABLE users DROP COLUMN IF EXISTS stripe_customer_id;
ALTER TABLE users DROP COLUMN IF EXISTS billing_country;
-- +migrate Down

View file

@ -25,14 +25,14 @@ import (
// Model is the base model definition
type Model struct {
ID int `gorm:"primaryKey" json:"-"`
CreatedAt time.Time `json:"created_at" gorm:"default:now()"`
UpdatedAt time.Time `json:"updated_at"`
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
// Book is a model for a book
type Book struct {
Model
UUID string `json:"uuid" gorm:"uniqueIndex;type:uuid;default:uuid_generate_v4()"`
UUID string `json:"uuid" gorm:"uniqueIndex;type:text"`
UserID int `json:"user_id" gorm:"index"`
Label string `json:"label" gorm:"index"`
Notes []Note `json:"notes" gorm:"foreignKey:BookUUID;references:UUID"`
@ -46,15 +46,14 @@ type Book struct {
// Note is a model for a note
type Note struct {
Model
UUID string `json:"uuid" gorm:"index;type:uuid;default:uuid_generate_v4()"`
UUID string `json:"uuid" gorm:"index;type:text"`
Book Book `json:"book" gorm:"foreignKey:BookUUID;references:UUID"`
User User `json:"user"`
UserID int `json:"user_id" gorm:"index"`
BookUUID string `json:"book_uuid" gorm:"index;type:uuid"`
BookUUID string `json:"book_uuid" gorm:"index;type:text"`
Body string `json:"content"`
AddedOn int64 `json:"added_on"`
EditedOn int64 `json:"edited_on"`
TSV string `json:"-" gorm:"type:tsvector"`
Public bool `json:"public" gorm:"default:false"`
USN int `json:"-" gorm:"index"`
Deleted bool `json:"-" gorm:"default:false"`
@ -65,11 +64,10 @@ type Note struct {
// User is a model for a user
type User struct {
Model
UUID string `json:"uuid" gorm:"type:uuid;index;default:uuid_generate_v4()"`
UUID string `json:"uuid" gorm:"type:text;index"`
Account Account `gorm:"foreignKey:UserID"`
LastLoginAt *time.Time `json:"-"`
MaxUSN int `json:"-" gorm:"default:0"`
Cloud bool `json:"-" gorm:"default:false"`
}
// Account is a model for an account
@ -90,21 +88,6 @@ type Token struct {
UsedAt *time.Time
}
// Notification is the learning notification sent to the user
type Notification struct {
Model
Type string
UserID int `gorm:"index"`
}
// EmailPreference is a preference per user for receiving email communication
type EmailPreference struct {
Model
UserID int `gorm:"index" json:"-"`
InactiveReminder bool `json:"inactive_reminder" gorm:"default:false"`
ProductUpdate bool `json:"product_update" gorm:"default:true"`
}
// Session represents a user session
type Session struct {
Model

View file

@ -1,21 +0,0 @@
#!/usr/bin/env bash
# create-migration.sh creates a new SQL migration file for the
# server side Postgres database using the sql-migrate tool.
set -eux
is_command () {
command -v "$1" >/dev/null 2>&1;
}
if ! is_command sql-migrate; then
echo "sql-migrate is not found. Please run install-sql-migrate.sh"
exit 1
fi
if [ "$#" == 0 ]; then
echo "filename not provided"
exit 1
fi
filename=$1
sql-migrate new -config=./sql-migrate.yml "$filename"

View file

@ -1,3 +0,0 @@
#!/usr/bin/env bash
go get -v github.com/rubenv/sql-migrate/...

View file

@ -1,8 +0,0 @@
# A configuration for sql-migrate tool for generating migrations
# using `sql-migrate new`. This file is not actually used for running
# migrations because we run them programmatically.
development:
dialect: postgres
datasource: dbname=dnote sslmode=disable
dir: ./migrations

View file

@ -1,127 +0,0 @@
/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
*
* This file is part of Dnote.
*
* Dnote is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Dnote is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with Dnote. If not, see <https://www.gnu.org/licenses/>.
*/
package job
import (
slog "log"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/mailer"
"gorm.io/gorm"
"github.com/pkg/errors"
"github.com/robfig/cron"
)
var (
// ErrEmptyDB is an error for missing database connection in the app configuration
ErrEmptyDB = errors.New("No database connection was provided")
// ErrEmptyClock is an error for missing clock in the app configuration
ErrEmptyClock = errors.New("No clock was provided")
// ErrEmptyWebURL is an error for missing WebURL content in the app configuration
ErrEmptyWebURL = errors.New("No WebURL was provided")
// ErrEmptyEmailTemplates is an error for missing EmailTemplates content in the app configuration
ErrEmptyEmailTemplates = errors.New("No EmailTemplate store was provided")
// ErrEmptyEmailBackend is an error for missing EmailBackend content in the app configuration
ErrEmptyEmailBackend = errors.New("No EmailBackend was provided")
)
// Runner is a configuration for job
type Runner struct {
DB *gorm.DB
Clock clock.Clock
EmailTmpl mailer.Templates
EmailBackend mailer.Backend
Config config.Config
}
// NewRunner returns a new runner
func NewRunner(db *gorm.DB, c clock.Clock, t mailer.Templates, b mailer.Backend, config config.Config) (Runner, error) {
ret := Runner{
DB: db,
EmailTmpl: t,
EmailBackend: b,
Clock: c,
Config: config,
}
if err := ret.validate(); err != nil {
return Runner{}, errors.Wrap(err, "validating runner configuration")
}
return ret, nil
}
func (r *Runner) validate() error {
if r.DB == nil {
return ErrEmptyDB
}
if r.Clock == nil {
return ErrEmptyClock
}
if r.EmailTmpl == nil {
return ErrEmptyEmailTemplates
}
if r.EmailBackend == nil {
return ErrEmptyEmailBackend
}
if r.Config.WebURL == "" {
return ErrEmptyWebURL
}
return nil
}
func scheduleJob(c *cron.Cron, spec string, cmd func()) {
s, err := cron.ParseStandard(spec)
if err != nil {
panic(errors.Wrap(err, "parsing schedule"))
}
c.Schedule(s, cron.FuncJob(cmd))
}
func (r *Runner) schedule(ch chan error) {
// Schedule jobs
cr := cron.New()
cr.Start()
ch <- nil
// Block forever
select {}
}
// Do starts the background tasks in a separate goroutine that runs forever
func (r *Runner) Do() error {
// validate
if err := r.validate(); err != nil {
return errors.Wrap(err, "validating job configurations")
}
ch := make(chan error)
go r.schedule(ch)
if err := <-ch; err != nil {
return errors.Wrap(err, "scheduling jobs")
}
slog.Println("Started background tasks")
return nil
}

View file

@ -1,104 +0,0 @@
/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
*
* This file is part of Dnote.
*
* Dnote is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Dnote is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with Dnote. If not, see <https://www.gnu.org/licenses/>.
*/
package job
import (
"fmt"
"testing"
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/mailer"
"github.com/dnote/dnote/pkg/server/testutils"
"gorm.io/gorm"
"github.com/pkg/errors"
)
func TestNewRunner(t *testing.T) {
testCases := []struct {
db *gorm.DB
clock clock.Clock
emailTmpl mailer.Templates
emailBackend mailer.Backend
webURL string
expectedErr error
}{
{
db: &gorm.DB{},
clock: clock.NewMock(),
emailTmpl: mailer.Templates{},
emailBackend: &testutils.MockEmailbackendImplementation{},
webURL: "http://mock.url",
expectedErr: nil,
},
{
db: nil,
clock: clock.NewMock(),
emailTmpl: mailer.Templates{},
emailBackend: &testutils.MockEmailbackendImplementation{},
webURL: "http://mock.url",
expectedErr: ErrEmptyDB,
},
{
db: &gorm.DB{},
clock: nil,
emailTmpl: mailer.Templates{},
emailBackend: &testutils.MockEmailbackendImplementation{},
webURL: "http://mock.url",
expectedErr: ErrEmptyClock,
},
{
db: &gorm.DB{},
clock: clock.NewMock(),
emailTmpl: nil,
emailBackend: &testutils.MockEmailbackendImplementation{},
webURL: "http://mock.url",
expectedErr: ErrEmptyEmailTemplates,
},
{
db: &gorm.DB{},
clock: clock.NewMock(),
emailTmpl: mailer.Templates{},
emailBackend: nil,
webURL: "http://mock.url",
expectedErr: ErrEmptyEmailBackend,
},
{
db: &gorm.DB{},
clock: clock.NewMock(),
emailTmpl: mailer.Templates{},
emailBackend: &testutils.MockEmailbackendImplementation{},
webURL: "",
expectedErr: ErrEmptyWebURL,
},
}
for idx, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) {
c := config.Load()
c.WebURL = tc.webURL
_, err := NewRunner(tc.db, tc.clock, tc.emailTmpl, tc.emailBackend, c)
assert.Equal(t, errors.Cause(err), tc.expectedErr, "error mismatch")
})
}
}

View file

@ -32,9 +32,19 @@ const (
fieldKeyTimestamp = "ts"
fieldKeyUnixTimestamp = "ts_unix"
levelInfo = "info"
levelWarn = "warn"
levelError = "error"
// LevelDebug represents debug log level
LevelDebug = "debug"
// LevelInfo represents info log level
LevelInfo = "info"
// LevelWarn represents warn log level
LevelWarn = "warn"
// LevelError represents error log level
LevelError = "error"
)
var (
// currentLevel is the currently configured log level
currentLevel = LevelInfo
)
// Fields represents a set of information to be included in the log
@ -58,19 +68,50 @@ func WithFields(fields Fields) Entry {
return newEntry(fields)
}
// SetLevel sets the global log level
func SetLevel(level string) {
currentLevel = level
}
// levelPriority returns a numeric priority for comparison
func levelPriority(level string) int {
switch level {
case LevelDebug:
return 0
case LevelInfo:
return 1
case LevelWarn:
return 2
case LevelError:
return 3
default:
return 1
}
}
// shouldLog returns true if the given level should be logged based on currentLevel
func shouldLog(level string) bool {
return levelPriority(level) >= levelPriority(currentLevel)
}
// Debug logs the given entry at a debug level
func (e Entry) Debug(msg string) {
e.write(LevelDebug, msg)
}
// Info logs the given entry at an info level
func (e Entry) Info(msg string) {
e.write(levelInfo, msg)
e.write(LevelInfo, msg)
}
// Warn logs the given entry at a warning level
func (e Entry) Warn(msg string) {
e.write(levelWarn, msg)
e.write(LevelWarn, msg)
}
// Error logs the given entry at an error level
func (e Entry) Error(msg string) {
e.write(levelError, msg)
e.write(LevelError, msg)
}
// ErrorWrap logs the given entry with the error message annotated by the given message
@ -106,6 +147,10 @@ func (e Entry) formatJSON(level, msg string) []byte {
}
func (e Entry) write(level, msg string) {
if !shouldLog(level) {
return
}
serialized := e.formatJSON(level, msg)
_, err := fmt.Fprintln(os.Stderr, string(serialized))
@ -114,6 +159,11 @@ func (e Entry) write(level, msg string) {
}
}
// Debug logs a debug message without additional fields
func Debug(msg string) {
newEntry(Fields{}).Debug(msg)
}
// Info logs an info message without additional fields
func Info(msg string) {
newEntry(Fields{}).Info(msg)

View file

@ -19,11 +19,10 @@
package mailer
import (
"fmt"
"log"
"os"
"strconv"
"github.com/dnote/dnote/pkg/server/log"
"github.com/pkg/errors"
"gopkg.in/gomail.v2"
)
@ -36,9 +35,21 @@ type Backend interface {
Queue(subject, from string, to []string, contentType, body string) error
}
// SimpleBackendImplementation is an implementation of the Backend
// EmailDialer is an interface for sending email messages
type EmailDialer interface {
DialAndSend(m ...*gomail.Message) error
}
// gomailDialer wraps gomail.Dialer to implement EmailDialer interface
type gomailDialer struct {
*gomail.Dialer
}
// DefaultBackend is an implementation of the Backend
// that sends an email without queueing.
type SimpleBackendImplementation struct {
type DefaultBackend struct {
Dialer EmailDialer
Enabled bool
}
type dialerParams struct {
@ -73,13 +84,31 @@ func getSMTPParams() (*dialerParams, error) {
return p, nil
}
// NewDefaultBackend creates a default backend
func NewDefaultBackend(enabled bool) (*DefaultBackend, error) {
p, err := getSMTPParams()
if err != nil {
return nil, err
}
d := gomail.NewDialer(p.Host, p.Port, p.Username, p.Password)
return &DefaultBackend{
Dialer: &gomailDialer{Dialer: d},
Enabled: enabled,
}, nil
}
// Queue is an implementation of Backend.Queue.
func (b *SimpleBackendImplementation) Queue(subject, from string, to []string, contentType, body string) error {
// If not production, never actually send an email
if os.Getenv("GO_ENV") != "PRODUCTION" {
log.Println("Not sending email because Dnote is not running in a production environment.")
log.Printf("Subject: %s, to: %s, from: %s", subject, to, from)
fmt.Println(body)
func (b *DefaultBackend) Queue(subject, from string, to []string, contentType, body string) error {
// If not enabled, just log the email
if !b.Enabled {
log.WithFields(log.Fields{
"subject": subject,
"to": to,
"from": from,
"body": body,
}).Info("Not sending email because email backend is not configured.")
return nil
}
@ -89,13 +118,7 @@ func (b *SimpleBackendImplementation) Queue(subject, from string, to []string, c
m.SetHeader("Subject", subject)
m.SetBody(contentType, body)
p, err := getSMTPParams()
if err != nil {
return errors.Wrap(err, "getting dialer params")
}
d := gomail.NewPlainDialer(p.Host, p.Port, p.Username, p.Password)
if err := d.DialAndSend(m); err != nil {
if err := b.Dialer.DialAndSend(m); err != nil {
return errors.Wrap(err, "dialing and sending email")
}

View file

@ -0,0 +1,107 @@
/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
*
* This file is part of Dnote.
*
* Dnote is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Dnote is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with Dnote. If not, see <https://www.gnu.org/licenses/>.
*/
package mailer
import (
"testing"
"gopkg.in/gomail.v2"
)
type mockDialer struct {
sentMessages []*gomail.Message
err error
}
func (m *mockDialer) DialAndSend(msgs ...*gomail.Message) error {
m.sentMessages = append(m.sentMessages, msgs...)
return m.err
}
func TestDefaultBackendQueue(t *testing.T) {
t.Run("enabled sends email", func(t *testing.T) {
mock := &mockDialer{}
backend := &DefaultBackend{
Dialer: mock,
Enabled: true,
}
err := backend.Queue("Test Subject", "alice@example.com", []string{"bob@example.com"}, "text/plain", "Test body")
if err != nil {
t.Fatalf("Queue failed: %v", err)
}
if len(mock.sentMessages) != 1 {
t.Errorf("expected 1 message sent, got %d", len(mock.sentMessages))
}
})
t.Run("disabled does not send email", func(t *testing.T) {
mock := &mockDialer{}
backend := &DefaultBackend{
Dialer: mock,
Enabled: false,
}
err := backend.Queue("Test Subject", "alice@example.com", []string{"bob@example.com"}, "text/plain", "Test body")
if err != nil {
t.Fatalf("Queue failed: %v", err)
}
if len(mock.sentMessages) != 0 {
t.Errorf("expected 0 messages sent when disabled, got %d", len(mock.sentMessages))
}
})
}
func TestNewDefaultBackend(t *testing.T) {
t.Run("with all env vars set", func(t *testing.T) {
t.Setenv("SmtpHost", "smtp.example.com")
t.Setenv("SmtpPort", "587")
t.Setenv("SmtpUsername", "user@example.com")
t.Setenv("SmtpPassword", "secret")
backend, err := NewDefaultBackend(true)
if err != nil {
t.Fatalf("NewDefaultBackend failed: %v", err)
}
if backend.Enabled != true {
t.Errorf("expected Enabled to be true, got %v", backend.Enabled)
}
if backend.Dialer == nil {
t.Error("expected Dialer to be set")
}
})
t.Run("missing SMTP config returns error", func(t *testing.T) {
t.Setenv("SmtpHost", "")
t.Setenv("SmtpPort", "")
t.Setenv("SmtpUsername", "")
t.Setenv("SmtpPassword", "")
_, err := NewDefaultBackend(true)
if err == nil {
t.Error("expected error when SMTP not configured")
}
if err != ErrSMTPNotConfigured {
t.Errorf("expected ErrSMTPNotConfigured, got %v", err)
}
})
}

View file

@ -21,13 +21,11 @@ package mailer
import (
"bytes"
"embed"
"fmt"
htemplate "html/template"
"io"
ttemplate "text/template"
"github.com/aymerick/douceur/inliner"
"github.com/dnote/dnote/pkg/server/mailer/templates"
"github.com/pkg/errors"
)
@ -40,13 +38,9 @@ var (
EmailTypeEmailVerification = "verify_email"
// EmailTypeWelcome represents an welcome email
EmailTypeWelcome = "welcome"
// EmailTypeInactiveReminder represents an inactivity reminder email
EmailTypeInactiveReminder = "inactive"
)
var (
// EmailKindHTML is the type of html email
EmailKindHTML = "text/html"
// EmailKindText is the type of text email
EmailKindText = "text/plain"
)
@ -60,9 +54,6 @@ type template interface {
// Templates holds the parsed email templates
type Templates map[string]template
//go:embed templates/src
var templateDir embed.FS
func getTemplateKey(name, kind string) string {
return fmt.Sprintf("%s.%s", name, kind)
}
@ -100,58 +91,21 @@ func NewTemplates() Templates {
if err != nil {
panic(errors.Wrap(err, "initializing password reset template"))
}
inactiveReminderText, err := initTextTmpl(EmailTypeInactiveReminder)
if err != nil {
panic(errors.Wrap(err, "initializing password reset template"))
}
T := Templates{}
T.set(EmailTypeResetPassword, EmailKindText, passwordResetText)
T.set(EmailTypeResetPasswordAlert, EmailKindText, passwordResetAlertText)
T.set(EmailTypeEmailVerification, EmailKindText, verifyEmailText)
T.set(EmailTypeWelcome, EmailKindText, welcomeText)
T.set(EmailTypeInactiveReminder, EmailKindText, inactiveReminderText)
return T
}
// initHTMLTmpl returns a template instance by parsing the template with the
// given name along with partials
func initHTMLTmpl(templateName string) (template, error) {
filename := fmt.Sprintf("templates/src/%s.html", templateName)
content, err := templateDir.ReadFile(filename)
if err != nil {
return nil, errors.Wrap(err, "reading template")
}
headerContent, err := templateDir.ReadFile("templates/header.html")
if err != nil {
return nil, errors.Wrap(err, "reading header template")
}
footerContent, err := templateDir.ReadFile("templates/footer.html")
if err != nil {
return nil, errors.Wrap(err, "reading footer template")
}
t := htemplate.New(templateName)
if _, err = t.Parse(string(content)); err != nil {
return nil, errors.Wrap(err, "parsing template")
}
if _, err = t.Parse(string(headerContent)); err != nil {
return nil, errors.Wrap(err, "parsing template")
}
if _, err = t.Parse(string(footerContent)); err != nil {
return nil, errors.Wrap(err, "parsing template")
}
return t, nil
}
// initTextTmpl returns a template instance by parsing the template with the given name
func initTextTmpl(templateName string) (template, error) {
filename := fmt.Sprintf("templates/src/%s.txt", templateName)
filename := fmt.Sprintf("%s.txt", templateName)
content, err := templateDir.ReadFile(filename)
content, err := templates.Files.ReadFile(filename)
if err != nil {
return nil, errors.Wrap(err, "reading template")
}
@ -165,7 +119,7 @@ func initTextTmpl(templateName string) (template, error) {
}
// Execute executes the template with the given name with the givn data
func (tmpl Templates) Execute(name, kind string, data interface{}) (string, error) {
func (tmpl Templates) Execute(name, kind string, data any) (string, error) {
t, err := tmpl.get(name, kind)
if err != nil {
return "", errors.Wrap(err, "getting template")
@ -176,15 +130,5 @@ func (tmpl Templates) Execute(name, kind string, data interface{}) (string, erro
return "", errors.Wrap(err, "executing the template")
}
// If HTML email, inline the CSS rules
if kind == EmailKindHTML {
html, err := inliner.Inline(buf.String())
if err != nil {
return "", errors.Wrap(err, "inlining the css rules")
}
return html, nil
}
return buf.String(), nil
}

View file

@ -26,6 +26,26 @@ import (
"github.com/pkg/errors"
)
func TestAllTemplatesInitialized(t *testing.T) {
tmpl := NewTemplates()
emailTypes := []string{
EmailTypeResetPassword,
EmailTypeResetPasswordAlert,
EmailTypeEmailVerification,
EmailTypeWelcome,
}
for _, emailType := range emailTypes {
t.Run(emailType, func(t *testing.T) {
_, err := tmpl.get(emailType, EmailKindText)
if err != nil {
t.Errorf("template %s not initialized: %v", emailType, err)
}
})
}
}
func TestEmailVerificationEmail(t *testing.T) {
testCases := []struct {
token string
@ -101,3 +121,79 @@ func TestResetPasswordEmail(t *testing.T) {
})
}
}
func TestWelcomeEmail(t *testing.T) {
testCases := []struct {
accountEmail string
webURL string
}{
{
accountEmail: "test@example.com",
webURL: "http://localhost:3000",
},
{
accountEmail: "user@example.org",
webURL: "http://localhost:3001",
},
}
tmpl := NewTemplates()
for _, tc := range testCases {
t.Run(fmt.Sprintf("with WebURL %s and email %s", tc.webURL, tc.accountEmail), func(t *testing.T) {
dat := WelcomeTmplData{
AccountEmail: tc.accountEmail,
WebURL: tc.webURL,
}
body, err := tmpl.Execute(EmailTypeWelcome, EmailKindText, dat)
if err != nil {
t.Fatal(errors.Wrap(err, "executing"))
}
if ok := strings.Contains(body, tc.webURL); !ok {
t.Errorf("email body did not contain %s", tc.webURL)
}
if ok := strings.Contains(body, tc.accountEmail); !ok {
t.Errorf("email body did not contain %s", tc.accountEmail)
}
})
}
}
func TestResetPasswordAlertEmail(t *testing.T) {
testCases := []struct {
accountEmail string
webURL string
}{
{
accountEmail: "test@example.com",
webURL: "http://localhost:3000",
},
{
accountEmail: "user@example.org",
webURL: "http://localhost:3001",
},
}
tmpl := NewTemplates()
for _, tc := range testCases {
t.Run(fmt.Sprintf("with WebURL %s and email %s", tc.webURL, tc.accountEmail), func(t *testing.T) {
dat := EmailResetPasswordAlertTmplData{
AccountEmail: tc.accountEmail,
WebURL: tc.webURL,
}
body, err := tmpl.Execute(EmailTypeResetPasswordAlert, EmailKindText, dat)
if err != nil {
t.Fatal(errors.Wrap(err, "executing"))
}
if ok := strings.Contains(body, tc.webURL); !ok {
t.Errorf("email body did not contain %s", tc.webURL)
}
if ok := strings.Contains(body, tc.accountEmail); !ok {
t.Errorf("email body did not contain %s", tc.accountEmail)
}
})
}
}

View file

@ -1,12 +0,0 @@
DBHost=localhost
DBPort=5433
DBName=dnote
DBUser=postgres
DBPassword=
SmtpUsername=mock-SmtpUsername
SmtpPassword=mock-SmtpPassword
SmtpHost=mock-SmtpHost
WebURL=http://localhost:3000
DisableRegistration=false

View file

@ -1 +0,0 @@
templates

View file

@ -1,13 +0,0 @@
# templates
Email templates
* `/src` contains templates.
## Development
Run the server to develop templates locally.
```
./dev.sh
```

View file

@ -1,21 +0,0 @@
#!/usr/bin/env bash
set -eux
PID=""
function cleanup {
if [ "$PID" != "" ]; then
kill "$PID"
fi
}
trap cleanup EXIT
while true; do
go build main.go
./main &
PID=$!
inotifywait -r -e modify .
kill $PID
done

View file

@ -1,144 +0,0 @@
/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
*
* This file is part of Dnote.
*
* Dnote is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Dnote is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with Dnote. If not, see <https://www.gnu.org/licenses/>.
*/
package main
import (
"log"
"net/http"
"github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/mailer"
"gorm.io/gorm"
"github.com/joho/godotenv"
_ "github.com/lib/pq"
)
func (c Context) passwordResetHandler(w http.ResponseWriter, r *http.Request) {
data := mailer.EmailResetPasswordTmplData{
AccountEmail: "alice@example.com",
Token: "testToken",
WebURL: "http://localhost:3000",
}
body, err := c.Tmpl.Execute(mailer.EmailTypeResetPassword, mailer.EmailKindText, data)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Write([]byte(body))
}
func (c Context) passwordResetAlertHandler(w http.ResponseWriter, r *http.Request) {
data := mailer.EmailResetPasswordAlertTmplData{
AccountEmail: "alice@example.com",
WebURL: "http://localhost:3000",
}
body, err := c.Tmpl.Execute(mailer.EmailTypeResetPasswordAlert, mailer.EmailKindText, data)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Write([]byte(body))
}
func (c Context) emailVerificationHandler(w http.ResponseWriter, r *http.Request) {
data := mailer.EmailVerificationTmplData{
Token: "testToken",
WebURL: "http://localhost:3000",
}
body, err := c.Tmpl.Execute(mailer.EmailTypeEmailVerification, mailer.EmailKindText, data)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Write([]byte(body))
}
func (c Context) welcomeHandler(w http.ResponseWriter, r *http.Request) {
data := mailer.WelcomeTmplData{
AccountEmail: "alice@example.com",
WebURL: "http://localhost:3000",
}
body, err := c.Tmpl.Execute(mailer.EmailTypeWelcome, mailer.EmailKindText, data)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Write([]byte(body))
}
func (c Context) inactiveHandler(w http.ResponseWriter, r *http.Request) {
data := mailer.InactiveReminderTmplData{
SampleNoteUUID: "some-uuid",
WebURL: "http://localhost:3000",
Token: "some-random-token",
}
body, err := c.Tmpl.Execute(mailer.EmailTypeInactiveReminder, mailer.EmailKindText, data)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Write([]byte(body))
}
func (c Context) homeHandler(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Email development server is running."))
}
func init() {
err := godotenv.Load(".env.dev")
if err != nil {
panic(err)
}
}
// Context is a context holding global information
type Context struct {
DB *gorm.DB
Tmpl mailer.Templates
}
func main() {
c := config.Load()
db := database.Open(c)
defer func() {
sqlDB, err := db.DB()
if err == nil {
sqlDB.Close()
}
}()
log.Println("Email template development server running on http://127.0.0.1:2300")
tmpl := mailer.NewTemplates()
ctx := Context{DB: db, Tmpl: tmpl}
http.HandleFunc("/", ctx.homeHandler)
http.HandleFunc("/email-verification", ctx.emailVerificationHandler)
http.HandleFunc("/password-reset", ctx.passwordResetHandler)
http.HandleFunc("/password-reset-alert", ctx.passwordResetAlertHandler)
http.HandleFunc("/welcome", ctx.welcomeHandler)
http.HandleFunc("/inactive-reminder", ctx.inactiveHandler)
log.Fatal(http.ListenAndServe(":2300", nil))
}

View file

@ -0,0 +1,5 @@
You are receiving this because you requested to reset the password of the '{{ .AccountEmail }}' Dnote account.
Please click on the following link, or paste this into your browser to complete the process:
{{ .WebURL }}/password-reset/{{ .Token }}

View file

@ -2,7 +2,7 @@ Hi,
This email is to notify you that the password for your Dnote account "{{ .AccountEmail }}" has changed.
If you did not initiate this password change, please notify us by replying, and reset your password at {{ .WebURL }}/password-reset
If you did not initiate this password change, reset your password at {{ .WebURL }}/password-reset.
Thanks.

View file

@ -1 +0,0 @@
CompileDaemon -directory=. -command="./templates" -include="*.html"

View file

@ -1,9 +0,0 @@
Hi, nothing has been added to your Dnote for some time.
What about revisiting one of your previous notes? {{ .WebURL }}/notes/{{ .SampleNoteUUID }}
You can add new notes at {{ .WebURL }}/new or using Dnote apps.
- Dnote team
UNSUBSCRIBE: {{ .WebURL }}/settings/notifications?token={{ .Token }}

View file

@ -1,9 +0,0 @@
You are receiving this because you (or someone else) requested to reset the password of the '{{ .AccountEmail }}' Dnote account.
Please click on the following link, or paste this into your browser to complete the process:
{{ .WebURL }}/password-reset/{{ .Token }}
You can reply to this message, if you have questions.
- Dnote team

View file

@ -16,20 +16,10 @@
* along with Dnote. If not, see <https://www.gnu.org/licenses/>.
*/
package app
// Package mailer provides a functionality to send emails
package templates
import (
"os"
"testing"
import "embed"
"github.com/dnote/dnote/pkg/server/testutils"
)
func TestMain(m *testing.M) {
testutils.InitTestDB()
code := m.Run()
testutils.ClearData(testutils.DB)
os.Exit(code)
}
//go:embed *.txt
var Files embed.FS

View file

@ -1,9 +1,5 @@
Hi.
Hi,
Welcome to Dnote! To verify your email, visit the following link:
{{ .WebURL }}/verify-email/{{ .Token }}
Thanks for using Dnote.
- Dnote team

View file

@ -10,7 +10,3 @@ If you ever forget your password, you can reset it at {{ .WebURL }}/password-res
SOURCE CODE
Dnote is open source and you can see the source code at https://github.com/dnote/dnote
Feel free to reply anytime. Thanks for using Dnote.
- Dnote team

View file

@ -21,19 +21,17 @@ package mailer
import (
"crypto/rand"
"encoding/base64"
"errors"
"github.com/dnote/dnote/pkg/server/database"
pkgErrors "github.com/pkg/errors"
"github.com/pkg/errors"
"gorm.io/gorm"
)
func generateRandomToken(bits int) (string, error) {
b := make([]byte, bits)
_, err := rand.Read(b)
if err != nil {
return "", pkgErrors.Wrap(err, "generating random bytes")
if _, err := rand.Read(b); err != nil {
return "", errors.Wrap(err, "generating random bytes")
}
return base64.URLEncoding.EncodeToString(b), nil
@ -49,7 +47,7 @@ func GetToken(db *gorm.DB, userID int, kind string) (database.Token, error) {
tokenVal, genErr := generateRandomToken(16)
if genErr != nil {
return tok, pkgErrors.Wrap(genErr, "generating token value")
return tok, errors.Wrap(genErr, "generating token value")
}
if errors.Is(err, gorm.ErrRecordNotFound) {
@ -59,12 +57,12 @@ func GetToken(db *gorm.DB, userID int, kind string) (database.Token, error) {
Value: tokenVal,
}
if err := db.Save(&tok).Error; err != nil {
return tok, pkgErrors.Wrap(err, "saving token")
return tok, errors.Wrap(err, "saving token")
}
return tok, nil
} else if err != nil {
return tok, pkgErrors.Wrap(err, "finding token")
return tok, errors.Wrap(err, "finding token")
}
return tok, nil

View file

@ -0,0 +1,83 @@
/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
*
* This file is part of Dnote.
*
* Dnote is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Dnote is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with Dnote. If not, see <https://www.gnu.org/licenses/>.
*/
package mailer
import (
"testing"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/testutils"
)
func TestGetToken(t *testing.T) {
db := testutils.InitMemoryDB(t)
userID := 1
tokenType := "email_verification"
t.Run("creates new token", func(t *testing.T) {
token, err := GetToken(db, userID, tokenType)
if err != nil {
t.Fatalf("GetToken failed: %v", err)
}
if token.UserID != userID {
t.Errorf("expected UserID %d, got %d", userID, token.UserID)
}
if token.Type != tokenType {
t.Errorf("expected Type %s, got %s", tokenType, token.Type)
}
if token.Value == "" {
t.Error("expected non-empty token Value")
}
if token.UsedAt != nil {
t.Error("expected UsedAt to be nil for new token")
}
})
t.Run("reuses unused token", func(t *testing.T) {
// Get token again - should return the same one
token2, err := GetToken(db, userID, tokenType)
if err != nil {
t.Fatalf("second GetToken failed: %v", err)
}
// Get first token to compare
var token1 database.Token
if err := db.Where("user_id = ? AND type = ?", userID, tokenType).First(&token1).Error; err != nil {
t.Fatalf("failed to get first token: %v", err)
}
if token1.ID != token2.ID {
t.Errorf("expected same token ID %d, got %d", token1.ID, token2.ID)
}
if token1.Value != token2.Value {
t.Errorf("expected same token Value %s, got %s", token1.Value, token2.Value)
}
// Verify only one token exists in database
var count int64
if err := db.Model(&database.Token{}).Where("user_id = ? AND type = ?", userID, tokenType).Count(&count).Error; err != nil {
t.Fatalf("failed to count tokens: %v", err)
}
if count != 1 {
t.Errorf("expected 1 token in database, got %d", count)
}
})
}

View file

@ -42,16 +42,3 @@ type WelcomeTmplData struct {
AccountEmail string
WebURL string
}
// InactiveReminderTmplData is a template data for welcome emails
type InactiveReminderTmplData struct {
SampleNoteUUID string
WebURL string
Token string
}
// EmailTypeSubscriptionConfirmationTmplData is a template data for reset password emails
type EmailTypeSubscriptionConfirmationTmplData struct {
AccountEmail string
WebURL string
}

View file

@ -21,8 +21,8 @@ package main
import (
"flag"
"fmt"
"log"
"net/http"
"os"
"github.com/dnote/dnote/pkg/clock"
"github.com/dnote/dnote/pkg/server/app"
@ -30,54 +30,81 @@ import (
"github.com/dnote/dnote/pkg/server/config"
"github.com/dnote/dnote/pkg/server/controllers"
"github.com/dnote/dnote/pkg/server/database"
"github.com/dnote/dnote/pkg/server/job"
"github.com/dnote/dnote/pkg/server/log"
"github.com/dnote/dnote/pkg/server/mailer"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"github.com/pkg/errors"
"gorm.io/gorm"
)
var port = flag.String("port", "3000", "port to connect to")
func initDB(c config.Config) *gorm.DB {
db, err := gorm.Open(postgres.Open(c.DB.GetConnectionStr()), &gorm.Config{})
if err != nil {
panic(errors.Wrap(err, "opening database connection"))
}
func initDB(dbPath string) *gorm.DB {
db := database.Open(dbPath)
database.InitSchema(db)
database.Migrate(db)
return db
}
func initApp(cfg config.Config) app.App {
db := initDB(cfg)
db := initDB(cfg.DBPath)
emailBackend, err := mailer.NewDefaultBackend(cfg.IsProd())
if err != nil {
emailBackend = &mailer.DefaultBackend{Enabled: false}
} else {
log.Info("Email backend configured")
}
return app.App{
DB: db,
Clock: clock.New(),
EmailTemplates: mailer.NewTemplates(),
EmailBackend: &mailer.SimpleBackendImplementation{},
Config: cfg,
HTTP500Page: cfg.HTTP500Page,
DB: db,
Clock: clock.New(),
EmailTemplates: mailer.NewTemplates(),
EmailBackend: emailBackend,
HTTP500Page: cfg.HTTP500Page,
AppEnv: cfg.AppEnv,
WebURL: cfg.WebURL,
DisableRegistration: cfg.DisableRegistration,
Port: cfg.Port,
DBPath: cfg.DBPath,
AssetBaseURL: cfg.AssetBaseURL,
}
}
func runJob(a app.App) error {
runner, err := job.NewRunner(a.DB, a.Clock, a.EmailTemplates, a.EmailBackend, a.Config)
func startCmd(args []string) {
startFlags := flag.NewFlagSet("start", flag.ExitOnError)
startFlags.Usage = func() {
fmt.Printf(`Usage:
dnote-server start [flags]
Flags:
`)
startFlags.PrintDefaults()
}
appEnv := startFlags.String("appEnv", "", "Application environment (env: APP_ENV, default: PRODUCTION)")
port := startFlags.String("port", "", "Server port (env: PORT, default: 3000)")
webURL := startFlags.String("webUrl", "", "Full URL to server without trailing slash (env: WebURL, example: https://example.com)")
dbPath := startFlags.String("dbPath", "", "Path to SQLite database file (env: DBPath, default: $XDG_DATA_HOME/dnote/server.db)")
disableRegistration := startFlags.Bool("disableRegistration", false, "Disable user registration (env: DisableRegistration, default: false)")
logLevel := startFlags.String("logLevel", "", "Log level: debug, info, warn, or error (env: LOG_LEVEL, default: info)")
startFlags.Parse(args)
cfg, err := config.New(config.Params{
AppEnv: *appEnv,
Port: *port,
WebURL: *webURL,
DBPath: *dbPath,
DisableRegistration: *disableRegistration,
LogLevel: *logLevel,
})
if err != nil {
return errors.Wrap(err, "getting a job runner")
}
if err := runner.Do(); err != nil {
return errors.Wrap(err, "running job")
fmt.Printf("Error: %s\n\n", err)
startFlags.Usage()
os.Exit(1)
}
return nil
}
func startCmd() {
cfg := config.Load()
cfg.SetAssetBaseURL("/static")
// Set log level
log.SetLevel(cfg.LogLevel)
app := initApp(cfg)
defer func() {
@ -87,13 +114,6 @@ func startCmd() {
}
}()
if err := database.Migrate(app.DB); err != nil {
panic(errors.Wrap(err, "running migrations"))
}
if err := runJob(app); err != nil {
panic(errors.Wrap(err, "running job"))
}
ctl := controllers.New(&app)
rc := controllers.RouteConfig{
WebRoutes: controllers.NewWebRoutes(&app, ctl),
@ -106,8 +126,15 @@ func startCmd() {
panic(errors.Wrap(err, "initializing router"))
}
log.Printf("Dnote version %s is running on port %s", buildinfo.Version, *port)
log.Fatalln(http.ListenAndServe(fmt.Sprintf(":%s", *port), r))
log.WithFields(log.Fields{
"version": buildinfo.Version,
"port": cfg.Port,
}).Info("Dnote server starting")
if err := http.ListenAndServe(fmt.Sprintf(":%s", cfg.Port), r); err != nil {
log.ErrorWrap(err, "server failed")
os.Exit(1)
}
}
func versionCmd() {
@ -115,29 +142,33 @@ func versionCmd() {
}
func rootCmd() {
fmt.Printf(`Dnote server - a simple personal knowledge base
fmt.Printf(`Dnote server - a simple command line notebook
Usage:
dnote-server [command]
dnote-server [command] [flags]
Available commands:
start: Start the server
start: Start the server (use 'dnote-server start --help' for flags)
version: Print the version
`)
}
func main() {
flag.Parse()
cmd := flag.Arg(0)
if len(os.Args) < 2 {
rootCmd()
return
}
cmd := os.Args[1]
switch cmd {
case "":
rootCmd()
case "start":
startCmd()
startCmd(os.Args[2:])
case "version":
versionCmd()
default:
fmt.Printf("Unknown command %s", cmd)
fmt.Printf("Unknown command %s\n", cmd)
rootCmd()
os.Exit(1)
}
}

Some files were not shown because too many files have changed in this diff Show more