diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 920af76f..ac44b4f3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,25 +11,6 @@ jobs: build: runs-on: ubuntu-22.04 - services: - postgres: - image: postgres:14 - env: - POSTGRES_USER: postgres - POSTGRES_PASSWORD: postgres - POSTGRES_DB: dnote_test - POSTGRES_PORT: 5432 - # Wait until postgres has started - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - - # Expose port to the host - ports: - - 5432:5432 - steps: - uses: actions/checkout@v5 - uses: actions/setup-go@v6 diff --git a/.github/workflows/release-server.yml b/.github/workflows/release-server.yml new file mode 100644 index 00000000..75657e1e --- /dev/null +++ b/.github/workflows/release-server.yml @@ -0,0 +1,91 @@ +name: Release Server + +on: + push: + tags: + - 'server-v*' + +jobs: + release: + runs-on: ubuntu-22.04 + permissions: + contents: write + + steps: + - uses: actions/checkout@v5 + - uses: actions/setup-go@v6 + with: + go-version: '>=1.25.0' + - uses: actions/setup-node@v4 + with: + node-version: '20' + + - name: Extract version from tag + id: version + run: | + TAG=${GITHUB_REF#refs/tags/server-v} + echo "version=$TAG" >> $GITHUB_OUTPUT + echo "Releasing version: $TAG" + + - name: Install dependencies + run: make install + + - name: Run tests + run: make test + + - name: Build server + run: make version=${{ steps.version.outputs.version }} build-server + + - name: Prepare Docker build context + run: | + VERSION="${{ steps.version.outputs.version }}" + cp build/server/dnote_server_${VERSION}_linux_amd64.tar.gz host/docker/ + + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_TOKEN }} + + - name: Build and push Docker image + uses: docker/build-push-action@v6 + with: + context: ./host/docker + push: true + tags: | + dnote/dnote:${{ steps.version.outputs.version }} + dnote/dnote:latest + build-args: | + tarballName=dnote_server_${{ steps.version.outputs.version }}_linux_amd64.tar.gz + + - name: Create GitHub release + env: + GH_TOKEN: ${{ github.token }} + run: | + VERSION="${{ steps.version.outputs.version }}" + TAG="server-v${VERSION}" + + # Determine if prerelease (version not matching major.minor.patch) + FLAGS="" + if [[ ! "$VERSION" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + FLAGS="--prerelease" + fi + + gh release create "$TAG" \ + build/server/*.tar.gz \ + build/server/*_checksums.txt \ + $FLAGS \ + --title="$TAG" \ + --notes="Please see the [CHANGELOG](https://github.com/dnote/dnote/blob/master/CHANGELOG.md)" \ + --draft + + - name: Push to Docker Hub + env: + DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} + DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }} + run: | + VERSION="${{ steps.version.outputs.version }}" + + echo "$DOCKER_TOKEN" | docker login -u "$DOCKER_USERNAME" --password-stdin + docker push dnote/dnote:${VERSION} + docker push dnote/dnote:latest diff --git a/.gitignore b/.gitignore index c34f93b0..57d82ddc 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ node_modules /test tmp +*.db +server diff --git a/Makefile b/Makefile index 09c3513f..3b91d6ae 100644 --- a/Makefile +++ b/Makefile @@ -48,12 +48,6 @@ test-e2e: @(${currentDir}/scripts/e2e/test.sh) .PHONY: test-e2e -test-selfhost: - @echo "==> running a smoke test for self-hosting" - - @${currentDir}/host/smoketest/run_test.sh ${tarballPath} -.PHONY: test-selfhost - # development dev-server: @echo "==> running dev environment" diff --git a/SELF_HOSTING.md b/SELF_HOSTING.md index 9edc82e1..0c52733c 100644 --- a/SELF_HOSTING.md +++ b/SELF_HOSTING.md @@ -4,48 +4,33 @@ This guide documents the steps for installing the Dnote server on your own machi ## Overview -Dnote server comes as a single binary file that you can simply download and run. It uses Postgres as the database. +Dnote server comes as a single binary file that you can simply download and run. It uses SQLite as the database. ## Installation -1. Install Postgres 11+. -2. Create a `dnote` database by running `createdb dnote` -3. Download the official Dnote server release from the [release page](https://github.com/dnote/dnote/releases). -4. Extract the archive and move the `dnote-server` executable to `/usr/local/bin`. +1. Download the official Dnote server release from the [release page](https://github.com/dnote/dnote/releases). +2. Extract the archive and move the `dnote-server` executable to `/usr/local/bin`. ```bash tar -xzf dnote-server-$version-$os.tar.gz mv ./dnote-server /usr/local/bin ``` -4. Run Dnote +3. Run Dnote ```bash -GO_ENV=PRODUCTION \ -OnPremises=true \ -DBHost=localhost \ -DBPort=5432 \ -DBName=dnote \ -DBUser=$user \ -DBPassword=$password \ -WebURL=$webURL \ -SmtpHost=$SmtpHost \ -SmtpPort=$SmtpPort \ -SmtpUsername=$SmtpUsername \ -SmtpPassword=$SmtpPassword \ -DisableRegistration=false \ - dnote-server start +dnote-server start --webUrl=$webURL ``` -Replace `$user`, `$password` with the credentials of the Postgres user that owns the `dnote` database. - Replace `$webURL` with the full URL to your server, without a trailing slash (e.g. `https://your.server`). -Replace `$SmtpHost`, `SmtpPort`, `$SmtpUsername`, `$SmtpPassword` with actual values, if you would like to receive spaced repetition through email. +Additional flags: +- `--port`: Server port (default: `3000`) +- `--disableRegistration`: Disable user registration (default: `false`) +- `--logLevel`: Log level: `debug`, `info`, `warn`, or `error` (default: `info`) +- `--appEnv`: environment (default: `PRODUCTION`) -Replace `DisableRegistration` to `true` if you would like to disable user registrations. - -By default, dnote server will run on the port 3000. +You can also use environment variables: `PORT`, `WebURL`, `DisableRegistration`, `LOG_LEVEL`, `APP_ENV`. ## Configuration @@ -127,33 +112,31 @@ User=$user Restart=always RestartSec=3 WorkingDirectory=/home/$user -ExecStart=/usr/local/bin/dnote-server start -Environment=GO_ENV=PRODUCTION -Environment=OnPremises=true -Environment=DBHost=localhost -Environment=DBPort=5432 -Environment=DBName=dnote -Environment=DBUser=$DBUser -Environment=DBPassword=$DBPassword -Environment=DBSkipSSL=true -Environment=WebURL=$WebURL -Environment=SmtpHost= -Environment=SmtpPort= -Environment=SmtpUsername= -Environment=SmtpPassword= +ExecStart=/usr/local/bin/dnote-server start --webUrl=$WebURL [Install] WantedBy=multi-user.target ``` -Replace `$user`, `$WebURL`, `$DBUser`, and `$DBPassword` with the actual values. +Replace `$user` and `$WebURL` with the actual values. -Optionally, if you would like to send spaced repetitions throught email, populate `SmtpHost`, `SmtpPort`, `SmtpUsername`, and `SmtpPassword`. +By default, the database will be stored at `$XDG_DATA_HOME/dnote/server.db` (typically `~/.local/share/dnote/server.db`). To use a custom location, add `--dbPath=/path/to/database.db` to the `ExecStart` command. 2. Reload the change by running `sudo systemctl daemon-reload`. 3. Enable the Daemon by running `sudo systemctl enable dnote`.` 4. Start the Daemon by running `sudo systemctl start dnote` +### Optional: Email Support + +To enable sending emails, add the following environment variables to your configuration. But they are not required. + +- `SmtpHost` - SMTP server hostname +- `SmtpPort` - SMTP server port +- `SmtpUsername` - SMTP username +- `SmtpPassword` - SMTP password + +For systemd, add these as additional `Environment=` lines in `/etc/systemd/system/dnote.service`. + ### Configure clients Let's configure Dnote clients to connect to the self-hosted web API endpoint. @@ -166,7 +149,7 @@ The following is an example configuration: ```yaml editor: nvim -apiEndpoint: https://api.getdnote.com +apiEndpoint: https://localhost:3000/api ``` Simply change the value for `apiEndpoint` to a full URL to the self-hosted instance, followed by '/api', and save the configuration file. @@ -177,7 +160,3 @@ e.g. editor: nvim apiEndpoint: my-dnote-server.com/api ``` - -#### Browser extension - -Navigate into the 'Settings' tab and set the values for 'API URL', and 'Web URL'. diff --git a/go.mod b/go.mod index 275818a9..3b87c5ec 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/dnote/dnote go 1.25 require ( - github.com/aymerick/douceur v0.2.0 github.com/dnote/actions v0.2.0 github.com/fatih/color v1.18.0 github.com/google/go-cmp v0.7.0 @@ -12,42 +11,34 @@ require ( github.com/gorilla/csrf v1.7.3 github.com/gorilla/mux v1.8.1 github.com/gorilla/schema v1.4.1 - github.com/joho/godotenv v1.5.1 - github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.32 github.com/pkg/errors v0.9.1 github.com/radovskyb/watcher v1.0.7 github.com/robfig/cron v1.2.0 - github.com/rubenv/sql-migrate v1.8.0 github.com/sergi/go-diff v1.3.1 github.com/spf13/cobra v1.10.1 golang.org/x/crypto v0.42.0 golang.org/x/time v0.13.0 gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df gopkg.in/yaml.v2 v2.4.0 - gorm.io/driver/postgres v1.5.7 - gorm.io/gorm v1.25.7 + gorm.io/driver/sqlite v1.6.0 + gorm.io/gorm v1.30.0 ) require ( - github.com/PuerkitoBio/goquery v1.10.3 // indirect - github.com/andybalholm/cascadia v1.3.3 // indirect - github.com/go-gorp/gorp/v3 v3.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect - github.com/gorilla/css v1.0.1 // indirect github.com/gorilla/securecookie v1.1.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.4.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect + github.com/kr/pretty v0.3.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/spf13/pflag v1.0.10 // indirect - golang.org/x/net v0.44.0 // indirect + github.com/stretchr/testify v1.8.1 // indirect golang.org/x/sys v0.36.0 // indirect golang.org/x/term v0.35.0 // indirect golang.org/x/text v0.29.0 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect ) diff --git a/go.sum b/go.sum index 6d099387..e175da0d 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,5 @@ -github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo= -github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y= -github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kktS1LM= -github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA= -github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= -github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -12,12 +7,7 @@ github.com/dnote/actions v0.2.0 h1:P1ut2/QRKwfAzIIB374vN9A4IanU94C/payEocvngYo= github.com/dnote/actions v0.2.0/go.mod h1:bBIassLhppVQdbC3iaE92SHBpM1HOVe+xZoAlj9ROxw= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= -github.com/go-gorp/gorp/v3 v3.1.0 h1:ItKF/Vbuj31dmV4jxA1qblpSwkl9g1typ24xoe70IGs= -github.com/go-gorp/gorp/v3 v3.1.0/go.mod h1:dLEjIyyRNiXvNZ8PSmzpt1GsWAUK8kjVhEpjH8TixEw= -github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= -github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY= @@ -30,8 +20,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0= github.com/gorilla/csrf v1.7.3/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= -github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= -github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/schema v1.4.1 h1:jUg5hUjCSDZpNGLuXQOgIWGdlgrIdYvgQ0wZtdK1M3E= @@ -40,47 +28,35 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= -github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= -github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= -github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/poy/onpar v1.1.2 h1:QaNrNiZx0+Nar5dLgTVp5mXkyoVFIbepjyEoGSnhbAY= -github.com/poy/onpar v1.1.2/go.mod h1:6X8FLNoxyr9kkmnlqpK6LSoiOtrO6MICtWwEuWkLjzg= github.com/radovskyb/watcher v1.0.7 h1:AYePLih6dpmS32vlHfhCeli8127LzkIgwJGcwwe8tUE= github.com/radovskyb/watcher v1.0.7/go.mod h1:78okwvY5wPdzcb1UYnip1pvrZNIVEIh/Cm+ZuvsUYIg= github.com/robfig/cron v1.2.0 h1:ZjScXvvxeQ63Dbyxy76Fj3AT3Ut0aKsyd2/tl3DTMuQ= github.com/robfig/cron v1.2.0/go.mod h1:JGuDeoQd7Z6yL4zQhZ3OPEVHB7fL6Ka6skscFHfmt2k= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/rubenv/sql-migrate v1.8.0 h1:dXnYiJk9k3wetp7GfQbKJcPHjVJL6YK19tKj8t2Ns0o= -github.com/rubenv/sql-migrate v1.8.0/go.mod h1:F2bGFBwCU+pnmbtNYDeKvSuvL6lBVtXDXUUv5t+u1qw= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= @@ -90,88 +66,24 @@ github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= -golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= -golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I= -golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= -golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= -golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= -golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ= golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI= golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= @@ -187,7 +99,7 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= -gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= -gorm.io/gorm v1.25.7 h1:VsD6acwRjz2zFxGO50gPO6AkNs7KKnvfzUjHQhZDz/A= -gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs= +gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= diff --git a/host/docker/compose.yml b/host/docker/compose.yml index 446e841d..d2c2e5ed 100644 --- a/host/docker/compose.yml +++ b/host/docker/compose.yml @@ -1,35 +1,14 @@ version: "3" services: - postgres: - image: postgres:14-alpine - environment: - POSTGRES_USER: dnote - POSTGRES_PASSWORD: dnote - POSTGRES_DB: dnote - volumes: - - ./dnote_data:/var/lib/postgresql/data - restart: always - dnote: image: dnote/dnote:latest environment: - GO_ENV: PRODUCTION - DBSkipSSL: "true" - DBHost: postgres - DBPort: 5432 - DBName: dnote - DBUser: dnote - DBPassword: dnote + APP_ENV: PRODUCTION WebURL: localhost:3000 - OnPremises: "true" - SmtpHost: - SmtpPort: - SmtpUsername: - SmtpPassword: DisableRegistration: "false" ports: - 3000:3000 - depends_on: - - postgres + volumes: + - ./dnote_data:/data restart: always diff --git a/host/docker/entrypoint.sh b/host/docker/entrypoint.sh index 8fb62de9..0fee185c 100755 --- a/host/docker/entrypoint.sh +++ b/host/docker/entrypoint.sh @@ -1,25 +1,6 @@ #!/bin/sh -wait_for_db() { - HOST=${DBHost:-postgres} - PORT=${DBPort:-5432} - echo "Waiting for the database connection..." - - attempts=0 - max_attempts=10 - while [ $attempts -lt $max_attempts ]; do - nc -z "${HOST}" "${PORT}" 2>/dev/null && break - echo "Waiting for db at ${HOST}:${PORT}..." - sleep 5 - attempts=$((attempts+1)) - done - - if [ $attempts -eq $max_attempts ]; then - echo "Timed out while waiting for db at ${HOST}:${PORT}" - exit 1 - fi -} - -wait_for_db +# Set default DBPath to /data if not specified +export DBPath=${DBPath:-/data/dnote.db} exec "$@" diff --git a/host/smoketest/.gitignore b/host/smoketest/.gitignore deleted file mode 100644 index 463ebfd4..00000000 --- a/host/smoketest/.gitignore +++ /dev/null @@ -1 +0,0 @@ -/volume diff --git a/host/smoketest/README.md b/host/smoketest/README.md deleted file mode 100644 index a428896a..00000000 --- a/host/smoketest/README.md +++ /dev/null @@ -1,9 +0,0 @@ -This directory contains a smoke test for running a self-hosted instance using a virtual machine. - -## Instruction - -The following script will set up a test environment in Vagrant and run the test. - -``` -./run_test.sh -``` diff --git a/host/smoketest/Vagrantfile b/host/smoketest/Vagrantfile deleted file mode 100644 index 2eaee5dd..00000000 --- a/host/smoketest/Vagrantfile +++ /dev/null @@ -1,9 +0,0 @@ -# -*- mode: ruby -*- - -Vagrant.configure("2") do |config| - config.vm.box = "ubuntu/jammy64" - config.vm.synced_folder './volume', '/vagrant' - config.vm.network "forwarded_port", guest: 2300, host: 2300 - - config.vm.provision 'shell', path: './setup.sh', privileged: false -end diff --git a/host/smoketest/run_test.sh b/host/smoketest/run_test.sh deleted file mode 100755 index 8137dfcd..00000000 --- a/host/smoketest/run_test.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/usr/bin/env bash -# run_test.sh builds a fresh server image, and mounts it on a fresh -# virtual machine and runs a smoke test. If a tarball path is not provided, -# this script builds a new version and uses it. -set -ex - -# tarballPath is an absolute path to a release tarball containing the dnote server. -tarballPath=$1 - -dir=$(dirname "${BASH_SOURCE[0]}") -projectDir="$dir/../.." - -# build -if [ -z "$tarballPath" ]; then - pushd "$projectDir" - make version=integration_test build-server - popd - tarballPath="$projectDir/build/server/dnote_server_integration_test_linux_amd64.tar.gz" -fi - -pushd "$dir" - -# start a virtual machine -volume="$dir/volume" -rm -rf "$volume" -mkdir -p "$volume" -cp "$tarballPath" "$volume" -cp "$dir/testsuite.sh" "$volume" - -vagrant up - -# run tests -set +e -if ! vagrant ssh -c "/vagrant/testsuite.sh"; then - echo "Test failed. Please see the output." - vagrant halt - exit 1 -fi -set -e - -vagrant halt -popd diff --git a/host/smoketest/setup.sh b/host/smoketest/setup.sh deleted file mode 100755 index 575b7fd8..00000000 --- a/host/smoketest/setup.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/usr/bin/env bash -set -ex - -sudo apt-get install wget ca-certificates -wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - -sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt/ `lsb_release -cs`-pgdg main" >> /etc/apt/sources.list.d/pgdg.list' - -sudo apt-get update -sudo apt-get install -y postgresql-14 - -# set up database -sudo usermod -a -G sudo postgres -cd /var/lib/postgresql -sudo -u postgres createdb dnote -sudo -u postgres psql -c "ALTER USER postgres PASSWORD 'postgres';" - -# allow connection from host and allow to connect without password -sudo sed -i "/port*/a listen_addresses = '*'" /etc/postgresql/14/main/postgresql.conf -sudo sed -i 's/host.*all.*.all.*md5/# &/' /etc/postgresql/14/main/pg_hba.conf -sudo sed -i "$ a host all all all trust" /etc/postgresql/14/main/pg_hba.conf -sudo service postgresql restart diff --git a/host/smoketest/testsuite.sh b/host/smoketest/testsuite.sh deleted file mode 100755 index 4fa05af3..00000000 --- a/host/smoketest/testsuite.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env bash -# testsuite.sh runs the smoke tests for a self-hosted instance. -# It is meant to be run inside a virtual machine which has been -# set up by an entry script. -set -eux - -echo 'Running a smoke test' - -cd /var/lib/postgresql -sudo -u postgres dropdb dnote -sudo -u postgres createdb dnote - -cd /vagrant - -tar -xvf dnote_server_integration_test_linux_amd64.tar.gz - -GO_ENV=PRODUCTION \ - DBHost=localhost \ - DBPort=5432 \ - DBName=dnote \ - DBUser=postgres \ - DBPassword=postgres \ - WebURL=localhost:3000 \ - ./dnote-server -port 2300 start & sleep 3 - -assert_http_status() { - url=$1 - expected=$2 - - echo "======== [TEST CASE] asserting response status code for $url ========" - - got=$(curl --write-out %"{http_code}" --silent --output /dev/null "$url") - - if [ "$got" != "$expected" ]; then - echo "======== ASSERTION FAILED ========" - echo "status code for $url: expected: $expected got: $got" - echo "==================================" - exit 1 - fi -} - -assert_http_status http://localhost:2300 "302" -assert_http_status http://localhost:2300/health "200" - -echo "======== [SUCCESS] TEST PASSED! ========" diff --git a/pkg/assert/assert.go b/pkg/assert/assert.go index 94c7b411..000d7b48 100644 --- a/pkg/assert/assert.go +++ b/pkg/assert/assert.go @@ -22,7 +22,7 @@ package assert import ( "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "reflect" "runtime/debug" @@ -138,7 +138,7 @@ func EqualJSON(t *testing.T, a, b, message string) { // expected func StatusCodeEquals(t *testing.T, res *http.Response, expected int, message string) { if res.StatusCode != expected { - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(errors.Wrap(err, "reading body")) } diff --git a/pkg/cli/client/client.go b/pkg/cli/client/client.go index 59cfdb59..b90ddaca 100644 --- a/pkg/cli/client/client.go +++ b/pkg/cli/client/client.go @@ -23,7 +23,7 @@ package client import ( "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "net/url" "strconv" @@ -95,7 +95,7 @@ func checkRespErr(res *http.Response) error { return nil } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { return errors.Wrapf(err, "server responded with %d but client could not read the response body", res.StatusCode) } @@ -169,7 +169,7 @@ func GetSyncState(ctx context.DnoteCtx) (GetSyncStateResp, error) { return ret, errors.Wrap(err, "constructing http request") } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { return ret, errors.Wrap(err, "reading the response body") } @@ -233,7 +233,7 @@ func GetSyncFragment(ctx context.DnoteCtx, afterUSN int) (GetSyncFragmentResp, e path := fmt.Sprintf("/v3/sync/fragment?%s", queryStr) res, err := doAuthorizedReq(ctx, "GET", path, "", nil) - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { return GetSyncFragmentResp{}, errors.Wrap(err, "reading the response body") } diff --git a/pkg/cli/cmd/edit/note.go b/pkg/cli/cmd/edit/note.go index ba8968c3..f6019778 100644 --- a/pkg/cli/cmd/edit/note.go +++ b/pkg/cli/cmd/edit/note.go @@ -20,7 +20,7 @@ package edit import ( "database/sql" - "io/ioutil" + "os" "strconv" "github.com/dnote/dnote/pkg/cli/context" @@ -45,7 +45,7 @@ func waitEditorNoteContent(ctx context.DnoteCtx, note database.Note) (string, er return "", errors.Wrap(err, "getting temporarily content file path") } - if err := ioutil.WriteFile(fpath, []byte(note.Body), 0644); err != nil { + if err := os.WriteFile(fpath, []byte(note.Body), 0644); err != nil { return "", errors.Wrap(err, "preparing tmp content file") } diff --git a/pkg/cli/config/config.go b/pkg/cli/config/config.go index 3ef76a6b..b0faa9cf 100644 --- a/pkg/cli/config/config.go +++ b/pkg/cli/config/config.go @@ -20,7 +20,7 @@ package config import ( "fmt" - "io/ioutil" + "os" "github.com/dnote/dnote/pkg/cli/consts" "github.com/dnote/dnote/pkg/cli/context" @@ -66,7 +66,7 @@ func Read(ctx context.DnoteCtx) (Config, error) { var ret Config configPath := GetPath(ctx) - b, err := ioutil.ReadFile(configPath) + b, err := os.ReadFile(configPath) if err != nil { return ret, errors.Wrap(err, "reading config file") } @@ -88,7 +88,7 @@ func Write(ctx context.DnoteCtx, cf Config) error { return errors.Wrap(err, "marshalling config into YAML") } - err = ioutil.WriteFile(path, b, 0644) + err = os.WriteFile(path, b, 0644) if err != nil { return errors.Wrap(err, "writing the config file") } diff --git a/pkg/cli/infra/init.go b/pkg/cli/infra/init.go index cecd022f..72286cad 100644 --- a/pkg/cli/infra/init.go +++ b/pkg/cli/infra/init.go @@ -32,11 +32,11 @@ import ( "github.com/dnote/dnote/pkg/cli/consts" "github.com/dnote/dnote/pkg/cli/context" "github.com/dnote/dnote/pkg/cli/database" - "github.com/dnote/dnote/pkg/cli/dirs" "github.com/dnote/dnote/pkg/cli/log" "github.com/dnote/dnote/pkg/cli/migrate" "github.com/dnote/dnote/pkg/cli/utils" "github.com/dnote/dnote/pkg/clock" + "github.com/dnote/dnote/pkg/dirs" "github.com/pkg/errors" "github.com/spf13/cobra" ) diff --git a/pkg/cli/migrate/legacy.go b/pkg/cli/migrate/legacy.go index d0dab97b..dc5441c7 100644 --- a/pkg/cli/migrate/legacy.go +++ b/pkg/cli/migrate/legacy.go @@ -23,7 +23,6 @@ package migrate import ( "encoding/json" "fmt" - "io/ioutil" "os" "time" @@ -232,7 +231,7 @@ func readSchema(ctx context.DnoteCtx) (schema, error) { path := getSchemaPath(ctx) - b, err := ioutil.ReadFile(path) + b, err := os.ReadFile(path) if err != nil { return ret, errors.Wrap(err, "Failed to read schema file") } @@ -252,7 +251,7 @@ func writeSchema(ctx context.DnoteCtx, s schema) error { return errors.Wrap(err, "Failed to marshal schema into yaml") } - if err := ioutil.WriteFile(path, d, 0644); err != nil { + if err := os.WriteFile(path, d, 0644); err != nil { return errors.Wrap(err, "Failed to write schema file") } @@ -504,7 +503,7 @@ func migrateToV1(ctx context.DnoteCtx) error { func migrateToV2(ctx context.DnoteCtx) error { notePath := fmt.Sprintf("%s/dnote", ctx.Paths.LegacyDnote) - b, err := ioutil.ReadFile(notePath) + b, err := os.ReadFile(notePath) if err != nil { return errors.Wrap(err, "Failed to read the note file") } @@ -548,7 +547,7 @@ func migrateToV2(ctx context.DnoteCtx) error { return errors.Wrap(err, "Failed to marshal new dnote into JSON") } - err = ioutil.WriteFile(notePath, d, 0644) + err = os.WriteFile(notePath, d, 0644) if err != nil { return errors.Wrap(err, "Failed to write the new dnote into the file") } @@ -561,7 +560,7 @@ func migrateToV3(ctx context.DnoteCtx) error { notePath := fmt.Sprintf("%s/dnote", ctx.Paths.LegacyDnote) actionsPath := fmt.Sprintf("%s/actions", ctx.Paths.LegacyDnote) - b, err := ioutil.ReadFile(notePath) + b, err := os.ReadFile(notePath) if err != nil { return errors.Wrap(err, "Failed to read the note file") } @@ -615,7 +614,7 @@ func migrateToV3(ctx context.DnoteCtx) error { return errors.Wrap(err, "Failed to marshal actions into JSON") } - err = ioutil.WriteFile(actionsPath, a, 0644) + err = os.WriteFile(actionsPath, a, 0644) if err != nil { return errors.Wrap(err, "Failed to write the actions into a file") } @@ -647,7 +646,7 @@ func getEditorCommand() string { func migrateToV4(ctx context.DnoteCtx) error { configPath := fmt.Sprintf("%s/dnoterc", ctx.Paths.LegacyDnote) - b, err := ioutil.ReadFile(configPath) + b, err := os.ReadFile(configPath) if err != nil { return errors.Wrap(err, "Failed to read the config file") } @@ -668,7 +667,7 @@ func migrateToV4(ctx context.DnoteCtx) error { return errors.Wrap(err, "Failed to marshal config into JSON") } - err = ioutil.WriteFile(configPath, data, 0644) + err = os.WriteFile(configPath, data, 0644) if err != nil { return errors.Wrap(err, "Failed to write the config into a file") } @@ -680,7 +679,7 @@ func migrateToV4(ctx context.DnoteCtx) error { func migrateToV5(ctx context.DnoteCtx) error { actionsPath := fmt.Sprintf("%s/actions", ctx.Paths.LegacyDnote) - b, err := ioutil.ReadFile(actionsPath) + b, err := os.ReadFile(actionsPath) if err != nil { return errors.Wrap(err, "reading the actions file") } @@ -738,7 +737,7 @@ func migrateToV5(ctx context.DnoteCtx) error { if err != nil { return errors.Wrap(err, "marshalling result into JSON") } - err = ioutil.WriteFile(actionsPath, a, 0644) + err = os.WriteFile(actionsPath, a, 0644) if err != nil { return errors.Wrap(err, "writing the result into a file") } @@ -750,7 +749,7 @@ func migrateToV5(ctx context.DnoteCtx) error { func migrateToV6(ctx context.DnoteCtx) error { notePath := fmt.Sprintf("%s/dnote", ctx.Paths.LegacyDnote) - b, err := ioutil.ReadFile(notePath) + b, err := os.ReadFile(notePath) if err != nil { return errors.Wrap(err, "Failed to read the note file") } @@ -791,7 +790,7 @@ func migrateToV6(ctx context.DnoteCtx) error { return errors.Wrap(err, "Failed to marshal new dnote into JSON") } - err = ioutil.WriteFile(notePath, d, 0644) + err = os.WriteFile(notePath, d, 0644) if err != nil { return errors.Wrap(err, "Failed to write the new dnote into the file") } @@ -805,7 +804,7 @@ func migrateToV6(ctx context.DnoteCtx) error { func migrateToV7(ctx context.DnoteCtx) error { actionPath := fmt.Sprintf("%s/actions", ctx.Paths.LegacyDnote) - b, err := ioutil.ReadFile(actionPath) + b, err := os.ReadFile(actionPath) if err != nil { return errors.Wrap(err, "reading actions file") } @@ -857,7 +856,7 @@ func migrateToV7(ctx context.DnoteCtx) error { return errors.Wrap(err, "marshalling new actions") } - err = ioutil.WriteFile(actionPath, d, 0644) + err = os.WriteFile(actionPath, d, 0644) if err != nil { return errors.Wrap(err, "writing new actions to a file") } @@ -874,7 +873,7 @@ func migrateToV8(ctx context.DnoteCtx) error { // 1. Migrate the the dnote file dnoteFilePath := fmt.Sprintf("%s/dnote", ctx.Paths.LegacyDnote) - b, err := ioutil.ReadFile(dnoteFilePath) + b, err := os.ReadFile(dnoteFilePath) if err != nil { return errors.Wrap(err, "reading the notes") } @@ -914,7 +913,7 @@ func migrateToV8(ctx context.DnoteCtx) error { // 2. Migrate the actions file actionsPath := fmt.Sprintf("%s/actions", ctx.Paths.LegacyDnote) - b, err = ioutil.ReadFile(actionsPath) + b, err = os.ReadFile(actionsPath) if err != nil { return errors.Wrap(err, "reading the actions") } @@ -939,7 +938,7 @@ func migrateToV8(ctx context.DnoteCtx) error { // 3. Migrate the timestamps file timestampsPath := fmt.Sprintf("%s/timestamps", ctx.Paths.LegacyDnote) - b, err = ioutil.ReadFile(timestampsPath) + b, err = os.ReadFile(timestampsPath) if err != nil { return errors.Wrap(err, "reading the timestamps") } diff --git a/pkg/cli/migrate/legacy_test.go b/pkg/cli/migrate/legacy_test.go index 211feee7..00ebb7d7 100644 --- a/pkg/cli/migrate/legacy_test.go +++ b/pkg/cli/migrate/legacy_test.go @@ -21,7 +21,6 @@ package migrate import ( "encoding/json" "fmt" - "io/ioutil" "os" "path/filepath" "testing" @@ -65,7 +64,7 @@ func TestMigrateToV1(t *testing.T) { if err != nil { panic(errors.Wrap(err, "Failed to get absolute YAML path").Error()) } - ioutil.WriteFile(yamlPath, []byte{}, 0644) + os.WriteFile(yamlPath, []byte{}, 0644) // execute if err := migrateToV1(ctx); err != nil { diff --git a/pkg/cli/migrate/migrate_test.go b/pkg/cli/migrate/migrate_test.go index bfda7768..cd2619bd 100644 --- a/pkg/cli/migrate/migrate_test.go +++ b/pkg/cli/migrate/migrate_test.go @@ -21,8 +21,8 @@ package migrate import ( "encoding/json" "fmt" - "io/ioutil" "net/http" + "os" "net/http/httptest" "testing" "time" @@ -1079,7 +1079,7 @@ func TestLocalMigration12(t *testing.T) { data := []byte("editor: vim") path := fmt.Sprintf("%s/%s/dnoterc", ctx.Paths.Config, consts.DnoteDirName) - if err := ioutil.WriteFile(path, data, 0644); err != nil { + if err := os.WriteFile(path, data, 0644); err != nil { t.Fatal(errors.Wrap(err, "Failed to write schema file")) } @@ -1090,7 +1090,7 @@ func TestLocalMigration12(t *testing.T) { } // test - b, err := ioutil.ReadFile(path) + b, err := os.ReadFile(path) if err != nil { t.Fatal(errors.Wrap(err, "reading config")) } @@ -1117,7 +1117,7 @@ func TestLocalMigration13(t *testing.T) { data := []byte("editor: vim\napiEndpoint: https://test.com/api") path := fmt.Sprintf("%s/%s/dnoterc", ctx.Paths.Config, consts.DnoteDirName) - if err := ioutil.WriteFile(path, data, 0644); err != nil { + if err := os.WriteFile(path, data, 0644); err != nil { t.Fatal(errors.Wrap(err, "Failed to write schema file")) } @@ -1128,7 +1128,7 @@ func TestLocalMigration13(t *testing.T) { } // test - b, err := ioutil.ReadFile(path) + b, err := os.ReadFile(path) if err != nil { t.Fatal(errors.Wrap(err, "reading config")) } diff --git a/pkg/cli/testutils/main.go b/pkg/cli/testutils/main.go index bdcc00d2..afc12f59 100644 --- a/pkg/cli/testutils/main.go +++ b/pkg/cli/testutils/main.go @@ -23,7 +23,6 @@ import ( "bytes" "encoding/json" "io" - "io/ioutil" "os" "os/exec" "path/filepath" @@ -81,7 +80,7 @@ func WriteFile(ctx context.DnoteCtx, content []byte, filename string) { panic(err) } - if err := ioutil.WriteFile(dp, content, 0644); err != nil { + if err := os.WriteFile(dp, content, 0644); err != nil { panic(err) } } @@ -90,7 +89,7 @@ func WriteFile(ctx context.DnoteCtx, content []byte, filename string) { func ReadFile(ctx context.DnoteCtx, filename string) []byte { path := filepath.Join(ctx.Paths.LegacyDnote, filename) - b, err := ioutil.ReadFile(path) + b, err := os.ReadFile(path) if err != nil { panic(err) } @@ -101,7 +100,7 @@ func ReadFile(ctx context.DnoteCtx, filename string) []byte { // ReadJSON reads JSON fixture to the struct at the destination address func ReadJSON(path string, destination interface{}) { var dat []byte - dat, err := ioutil.ReadFile(path) + dat, err := os.ReadFile(path) if err != nil { panic(errors.Wrap(err, "Failed to load fixture payload")) } diff --git a/pkg/cli/ui/editor.go b/pkg/cli/ui/editor.go index afafcaa3..3501e7ae 100644 --- a/pkg/cli/ui/editor.go +++ b/pkg/cli/ui/editor.go @@ -21,7 +21,6 @@ package ui import ( "fmt" - "io/ioutil" "os" "os/exec" "strings" @@ -122,7 +121,7 @@ func GetEditorInput(ctx context.DnoteCtx, fpath string) (string, error) { return "", errors.Wrap(err, "waiting for the editor") } - b, err := ioutil.ReadFile(fpath) + b, err := os.ReadFile(fpath) if err != nil { return "", errors.Wrap(err, "reading the temporary content file") } diff --git a/pkg/cli/utils/files.go b/pkg/cli/utils/files.go index 1335fc2b..b4b2d1df 100644 --- a/pkg/cli/utils/files.go +++ b/pkg/cli/utils/files.go @@ -20,7 +20,6 @@ package utils import ( "io" - "io/ioutil" "os" "path/filepath" @@ -35,7 +34,7 @@ func ReadFileAbs(relpath string) []byte { panic(err) } - b, err := ioutil.ReadFile(fp) + b, err := os.ReadFile(fp) if err != nil { panic(err) } @@ -80,7 +79,7 @@ func CopyDir(src, dest string) error { return errors.Wrap(err, "creating destination") } - entries, err := ioutil.ReadDir(src) + entries, err := os.ReadDir(src) if err != nil { return errors.Wrap(err, "reading the directory listing for the input") } diff --git a/pkg/cli/dirs/dirs.go b/pkg/dirs/dirs.go similarity index 100% rename from pkg/cli/dirs/dirs.go rename to pkg/dirs/dirs.go diff --git a/pkg/cli/dirs/dirs_test.go b/pkg/dirs/dirs_test.go similarity index 96% rename from pkg/cli/dirs/dirs_test.go rename to pkg/dirs/dirs_test.go index 0e5b0394..6fe3d1e0 100644 --- a/pkg/cli/dirs/dirs_test.go +++ b/pkg/dirs/dirs_test.go @@ -19,7 +19,6 @@ package dirs import ( - "os" "testing" "github.com/dnote/dnote/pkg/assert" @@ -34,7 +33,7 @@ type envTestCase struct { func testCustomDirs(t *testing.T, testCases []envTestCase) { for _, tc := range testCases { - os.Setenv(tc.envKey, tc.envVal) + t.Setenv(tc.envKey, tc.envVal) Reload() diff --git a/pkg/cli/dirs/dirs_unix.go b/pkg/dirs/dirs_unix.go similarity index 100% rename from pkg/cli/dirs/dirs_unix.go rename to pkg/dirs/dirs_unix.go diff --git a/pkg/cli/dirs/dirs_unix_test.go b/pkg/dirs/dirs_unix_test.go similarity index 100% rename from pkg/cli/dirs/dirs_unix_test.go rename to pkg/dirs/dirs_unix_test.go diff --git a/pkg/cli/dirs/dirs_windows.go b/pkg/dirs/dirs_windows.go similarity index 100% rename from pkg/cli/dirs/dirs_windows.go rename to pkg/dirs/dirs_windows.go diff --git a/pkg/cli/dirs/dirs_windows_test.go b/pkg/dirs/dirs_windows_test.go similarity index 100% rename from pkg/cli/dirs/dirs_windows_test.go rename to pkg/dirs/dirs_windows_test.go diff --git a/pkg/e2e/server_test.go b/pkg/e2e/server_test.go new file mode 100644 index 00000000..52a2d311 --- /dev/null +++ b/pkg/e2e/server_test.go @@ -0,0 +1,183 @@ +/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors + * + * This file is part of Dnote. + * + * Dnote is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Dnote is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with Dnote. If not, see . + */ + +package main + +import ( + "fmt" + "net/http" + "os" + "os/exec" + "strings" + "testing" + "time" + + "github.com/dnote/dnote/pkg/assert" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +var testServerBinary string + +func init() { + // Build server binary in temp directory + tmpDir := os.TempDir() + testServerBinary = fmt.Sprintf("%s/dnote-test-server", tmpDir) + buildCmd := exec.Command("go", "build", "-tags", "fts5", "-o", testServerBinary, "../server") + if out, err := buildCmd.CombinedOutput(); err != nil { + panic(fmt.Sprintf("failed to build server: %v\n%s", err, out)) + } +} + +func TestServerStart(t *testing.T) { + tmpDB := t.TempDir() + "/test.db" + port := "13456" // Use different port to avoid conflicts with main test server + + // Start server in background + cmd := exec.Command(testServerBinary, "start", "-port", port) + cmd.Env = append(os.Environ(), + "DBPath="+tmpDB, + "WebURL=http://localhost:"+port, + "APP_ENV=PRODUCTION", + ) + + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start server: %v", err) + } + + // Ensure cleanup + cleanup := func() { + if cmd.Process != nil { + cmd.Process.Kill() + cmd.Wait() // Wait for process to fully exit + } + } + defer cleanup() + + // Wait for server to start and migrations to run + time.Sleep(3 * time.Second) + + // Verify server responds to health check + resp, err := http.Get(fmt.Sprintf("http://localhost:%s/health", port)) + if err != nil { + t.Fatalf("failed to reach server health endpoint: %v", err) + } + defer resp.Body.Close() + + assert.Equal(t, resp.StatusCode, 200, "health endpoint should return 200") + + // Kill server before checking database to avoid locks + cleanup() + + // Verify database file was created + if _, err := os.Stat(tmpDB); os.IsNotExist(err) { + t.Fatalf("database file was not created at %s", tmpDB) + } + + // Verify migrations ran by checking database + db, err := gorm.Open(sqlite.Open(tmpDB), &gorm.Config{}) + if err != nil { + t.Fatalf("failed to open test database: %v", err) + } + + // Verify migrations ran + var count int64 + if err := db.Raw("SELECT COUNT(*) FROM schema_migrations").Scan(&count).Error; err != nil { + t.Fatalf("schema_migrations table not found: %v", err) + } + if count == 0 { + t.Fatal("no migrations were run") + } + + // Verify FTS table exists and is functional + if err := db.Exec("SELECT * FROM notes_fts LIMIT 1").Error; err != nil { + t.Fatalf("notes_fts table not found or not functional: %v", err) + } +} + +func TestServerVersion(t *testing.T) { + cmd := exec.Command("go", "run", "-tags", "fts5", "../server", "version") + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("version command failed: %v", err) + } + + outputStr := string(output) + if !strings.Contains(outputStr, "dnote-server-") { + t.Errorf("expected version output to contain 'dnote-server-', got: %s", outputStr) + } +} + +func TestServerRootCommand(t *testing.T) { + cmd := exec.Command(testServerBinary) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("server command failed: %v", err) + } + + outputStr := string(output) + assert.Equal(t, strings.Contains(outputStr, "Dnote server - a simple command line notebook"), true, "output should contain description") + assert.Equal(t, strings.Contains(outputStr, "start: Start the server"), true, "output should contain start command") + assert.Equal(t, strings.Contains(outputStr, "version: Print the version"), true, "output should contain version command") +} + +func TestServerStartHelp(t *testing.T) { + cmd := exec.Command(testServerBinary, "start", "--help") + output, _ := cmd.CombinedOutput() + + outputStr := string(output) + assert.Equal(t, strings.Contains(outputStr, "dnote-server start [flags]"), true, "output should contain usage") + assert.Equal(t, strings.Contains(outputStr, "-appEnv"), true, "output should contain appEnv flag") + assert.Equal(t, strings.Contains(outputStr, "-port"), true, "output should contain port flag") + assert.Equal(t, strings.Contains(outputStr, "-webUrl"), true, "output should contain webUrl flag") + assert.Equal(t, strings.Contains(outputStr, "-dbPath"), true, "output should contain dbPath flag") + assert.Equal(t, strings.Contains(outputStr, "-disableRegistration"), true, "output should contain disableRegistration flag") +} + +func TestServerStartInvalidConfig(t *testing.T) { + cmd := exec.Command(testServerBinary, "start") + // Clear WebURL env var so validation fails + cmd.Env = []string{} + + output, err := cmd.CombinedOutput() + + // Should exit with non-zero status + if err == nil { + t.Fatal("expected command to fail with invalid config") + } + + outputStr := string(output) + assert.Equal(t, strings.Contains(outputStr, "Error:"), true, "output should contain error message") + assert.Equal(t, strings.Contains(outputStr, "Invalid WebURL"), true, "output should mention invalid WebURL") + assert.Equal(t, strings.Contains(outputStr, "dnote-server start [flags]"), true, "output should show usage") + assert.Equal(t, strings.Contains(outputStr, "-webUrl"), true, "output should show flags") +} + +func TestServerUnknownCommand(t *testing.T) { + cmd := exec.Command(testServerBinary, "unknown") + output, err := cmd.CombinedOutput() + + // Should exit with non-zero status + if err == nil { + t.Fatal("expected command to fail with unknown command") + } + + outputStr := string(output) + assert.Equal(t, strings.Contains(outputStr, "Unknown command"), true, "output should contain unknown command message") + assert.Equal(t, strings.Contains(outputStr, "Dnote server - a simple command line notebook"), true, "output should show help") +} diff --git a/pkg/e2e/sync_test.go b/pkg/e2e/sync_test.go index d51cbed4..2f3234f2 100644 --- a/pkg/e2e/sync_test.go +++ b/pkg/e2e/sync_test.go @@ -22,7 +22,7 @@ import ( "bytes" "encoding/json" "fmt" - "io/ioutil" + "io" "log" "net/http" "net/http/httptest" @@ -39,16 +39,17 @@ import ( clitest "github.com/dnote/dnote/pkg/cli/testutils" "github.com/dnote/dnote/pkg/clock" "github.com/dnote/dnote/pkg/server/app" - "github.com/dnote/dnote/pkg/server/config" "github.com/dnote/dnote/pkg/server/controllers" "github.com/dnote/dnote/pkg/server/database" "github.com/dnote/dnote/pkg/server/mailer" apitest "github.com/dnote/dnote/pkg/server/testutils" "github.com/pkg/errors" + "gorm.io/gorm" ) var cliBinaryName string var server *httptest.Server +var serverDb *gorm.DB var serverTime = time.Date(2017, time.March, 14, 21, 15, 0, 0, time.UTC) var tmpDirPath string @@ -82,22 +83,22 @@ func clearTmp(t *testing.T) { } func TestMain(m *testing.M) { - // Set up server database - apitest.InitTestDB() + // Set up server database - use file-based DB for e2e tests + dbPath := fmt.Sprintf("%s/server.db", testDir) + serverDb = apitest.InitDB(dbPath) mockClock := clock.NewMock() mockClock.SetNow(serverTime) + a := app.NewTest() + a.Clock = mockClock + a.EmailTemplates = mailer.Templates{} + a.EmailBackend = &apitest.MockEmailbackendImplementation{} + a.DB = serverDb + a.WebURL = os.Getenv("WebURL") + var err error - server, err = controllers.NewServer(&app.App{ - Clock: mockClock, - EmailTemplates: mailer.Templates{}, - EmailBackend: &apitest.MockEmailbackendImplementation{}, - DB: apitest.DB, - Config: config.Config{ - WebURL: os.Getenv("WebURL"), - }, - }) + server, err = controllers.NewServer(&a) if err != nil { panic(errors.Wrap(err, "initializing router")) } @@ -124,11 +125,11 @@ func TestMain(m *testing.M) { // helpers func setupUser(t *testing.T, ctx *context.DnoteCtx) database.User { - user := apitest.SetupUserData() - apitest.SetupAccountData(user, "alice@example.com", "pass1234") + user := apitest.SetupUserData(serverDb) + apitest.SetupAccountData(serverDb, user, "alice@example.com", "pass1234") // log in the user in CLI - session := apitest.SetupSession(t, user) + session := apitest.SetupSession(serverDb, user) cliDatabase.MustExec(t, "inserting session_key", ctx.DB, "INSERT INTO system (key, value) VALUES (?, ?)", consts.SystemSessionKey, session.Key) cliDatabase.MustExec(t, "inserting session_key_expiry", ctx.DB, "INSERT INTO system (key, value) VALUES (?, ?)", consts.SystemSessionKeyExpiry, session.ExpiresAt.Unix()) @@ -184,9 +185,9 @@ func doHTTPReq(t *testing.T, method, path, payload, message string, user databas panic(errors.Wrap(err, "constructing http request")) } - res := apitest.HTTPAuthDo(t, req, user) + res := apitest.HTTPAuthDo(t, serverDb, req, user) if res.StatusCode >= 400 { - bs, err := ioutil.ReadAll(res.Body) + bs, err := io.ReadAll(res.Body) if err != nil { panic(errors.Wrap(err, "parsing response body for error")) } @@ -202,8 +203,8 @@ type assertFunc func(t *testing.T, ctx context.DnoteCtx, user database.User, ids func testSyncCmd(t *testing.T, fullSync bool, setup setupFunc, assert assertFunc) { // clean up - apitest.ClearData(apitest.DB) - defer apitest.ClearData(apitest.DB) + apitest.ClearData(serverDb) + defer apitest.ClearData(serverDb) clearTmp(t) @@ -234,7 +235,6 @@ type systemState struct { // checkState compares the state of the client and the server with the given system state func checkState(t *testing.T, ctx context.DnoteCtx, user database.User, expected systemState) { - serverDB := apitest.DB clientDB := ctx.DB var clientBookCount, clientNoteCount int @@ -251,12 +251,12 @@ func checkState(t *testing.T, ctx context.DnoteCtx, user database.User, expected assert.Equal(t, clientLastSyncAt, expected.clientLastSyncAt, "client last_sync_at mismatch") var serverBookCount, serverNoteCount int64 - apitest.MustExec(t, serverDB.Model(&database.Note{}).Count(&serverNoteCount), "counting server notes") - apitest.MustExec(t, serverDB.Model(&database.Book{}).Count(&serverBookCount), "counting api notes") + apitest.MustExec(t, serverDb.Model(&database.Note{}).Count(&serverNoteCount), "counting server notes") + apitest.MustExec(t, serverDb.Model(&database.Book{}).Count(&serverBookCount), "counting api notes") assert.Equal(t, serverNoteCount, expected.serverNoteCount, "server note count mismatch") assert.Equal(t, serverBookCount, expected.serverBookCount, "server book count mismatch") var serverUser database.User - apitest.MustExec(t, serverDB.Where("id = ?", user.ID).First(&serverUser), "finding user") + apitest.MustExec(t, serverDb.Where("id = ?", user.ID).First(&serverUser), "finding user") assert.Equal(t, serverUser.MaxUSN, expected.serverUserMaxUSN, "user max_usn mismatch") } @@ -286,8 +286,7 @@ func TestSync_Empty(t *testing.T) { func TestSync_oneway(t *testing.T) { t.Run("cli to api only", func(t *testing.T) { setup := func(t *testing.T, ctx context.DnoteCtx, user database.User) { - apiDB := apitest.DB - apitest.MustExec(t, apiDB.Model(&user).Update("max_usn", 0), "updating user max_usn") + apitest.MustExec(t, serverDb.Model(&user).Update("max_usn", 0), "updating user max_usn") clitest.RunDnoteCmd(t, dnoteCmdOpts, cliBinaryName, "add", "js", "-c", "js1") clitest.RunDnoteCmd(t, dnoteCmdOpts, cliBinaryName, "add", "css", "-c", "css1") @@ -295,7 +294,6 @@ func TestSync_oneway(t *testing.T) { } assert := func(t *testing.T, ctx context.DnoteCtx, user database.User) { - apiDB := apitest.DB cliDB := ctx.DB // test client @@ -339,11 +337,11 @@ func TestSync_oneway(t *testing.T) { // test server var apiBookJS, apiBookCSS database.Book var apiNote1JS, apiNote2JS, apiNote1CSS database.Note - apitest.MustExec(t, apiDB.Model(&database.Note{}).Where("uuid = ?", cliNote1JS.UUID).First(&apiNote1JS), "getting js1 note") - apitest.MustExec(t, apiDB.Model(&database.Note{}).Where("uuid = ?", cliNote2JS.UUID).First(&apiNote2JS), "getting js2 note") - apitest.MustExec(t, apiDB.Model(&database.Note{}).Where("uuid = ?", cliNote1CSS.UUID).First(&apiNote1CSS), "getting css1 note") - apitest.MustExec(t, apiDB.Model(&database.Book{}).Where("uuid = ?", cliBookJS.UUID).First(&apiBookJS), "getting js book") - apitest.MustExec(t, apiDB.Model(&database.Book{}).Where("uuid = ?", cliBookCSS.UUID).First(&apiBookCSS), "getting css book") + apitest.MustExec(t, serverDb.Model(&database.Note{}).Where("uuid = ?", cliNote1JS.UUID).First(&apiNote1JS), "getting js1 note") + apitest.MustExec(t, serverDb.Model(&database.Note{}).Where("uuid = ?", cliNote2JS.UUID).First(&apiNote2JS), "getting js2 note") + apitest.MustExec(t, serverDb.Model(&database.Note{}).Where("uuid = ?", cliNote1CSS.UUID).First(&apiNote1CSS), "getting css1 note") + apitest.MustExec(t, serverDb.Model(&database.Book{}).Where("uuid = ?", cliBookJS.UUID).First(&apiBookJS), "getting js book") + apitest.MustExec(t, serverDb.Model(&database.Book{}).Where("uuid = ?", cliBookCSS.UUID).First(&apiBookCSS), "getting css book") // assert usn assert.NotEqual(t, apiNote1JS.USN, 0, "apiNote1JS usn mismatch") @@ -371,7 +369,7 @@ func TestSync_oneway(t *testing.T) { t.Run("stepSync", func(t *testing.T) { clearTmp(t) - defer apitest.ClearData(apitest.DB) + defer apitest.ClearData(serverDb) ctx := context.InitTestCtx(t, paths, nil) defer context.TeardownTestCtx(t, ctx) @@ -385,7 +383,7 @@ func TestSync_oneway(t *testing.T) { t.Run("fullSync", func(t *testing.T) { clearTmp(t) - defer apitest.ClearData(apitest.DB) + defer apitest.ClearData(serverDb) ctx := context.InitTestCtx(t, paths, nil) defer context.TeardownTestCtx(t, ctx) @@ -400,7 +398,7 @@ func TestSync_oneway(t *testing.T) { t.Run("cli to api with edit and delete", func(t *testing.T) { setup := func(t *testing.T, ctx context.DnoteCtx, user database.User) { - apiDB := apitest.DB + apiDB := serverDb apitest.MustExec(t, apiDB.Model(&user).Update("max_usn", 0), "updating user max_usn") clitest.RunDnoteCmd(t, dnoteCmdOpts, cliBinaryName, "add", "js", "-c", "js1") @@ -423,7 +421,7 @@ func TestSync_oneway(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 6, @@ -527,7 +525,7 @@ func TestSync_oneway(t *testing.T) { t.Run("stepSync", func(t *testing.T) { clearTmp(t) - defer apitest.ClearData(apitest.DB) + defer apitest.ClearData(serverDb) ctx := context.InitTestCtx(t, paths, nil) defer context.TeardownTestCtx(t, ctx) @@ -541,7 +539,7 @@ func TestSync_oneway(t *testing.T) { t.Run("fullSync", func(t *testing.T) { clearTmp(t) - defer apitest.ClearData(apitest.DB) + defer apitest.ClearData(serverDb) ctx := context.InitTestCtx(t, paths, nil) defer context.TeardownTestCtx(t, ctx) @@ -556,7 +554,7 @@ func TestSync_oneway(t *testing.T) { t.Run("api to cli", func(t *testing.T) { setup := func(t *testing.T, ctx context.DnoteCtx, user database.User) map[string]string { - apiDB := apitest.DB + apiDB := serverDb apitest.MustExec(t, apiDB.Model(&user).Update("max_usn", 0), "updating user max_usn") @@ -603,7 +601,7 @@ func TestSync_oneway(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 6, @@ -804,7 +802,7 @@ func TestSync_twoway(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 9, @@ -1033,7 +1031,7 @@ func TestSync_twoway(t *testing.T) { } assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { - apiDB := apitest.DB + apiDB := serverDb cliDB := ctx.DB checkState(t, ctx, user, systemState{ @@ -1188,7 +1186,7 @@ func TestSync_twoway(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 4, @@ -1287,7 +1285,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 1, @@ -1349,7 +1347,7 @@ func TestSync(t *testing.T) { } assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 0, @@ -1404,7 +1402,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 0, @@ -1471,7 +1469,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 1, @@ -1537,7 +1535,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 1, @@ -1601,7 +1599,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 0, @@ -1655,7 +1653,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 0, @@ -1708,7 +1706,7 @@ func TestSync(t *testing.T) { } assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 0, @@ -1750,7 +1748,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 1, @@ -1816,7 +1814,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 1, @@ -1884,7 +1882,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 1, @@ -1957,7 +1955,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 0, @@ -2021,7 +2019,7 @@ func TestSync(t *testing.T) { } assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 0, @@ -2081,7 +2079,7 @@ func TestSync(t *testing.T) { } assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { - apiDB := apitest.DB + apiDB := serverDb cliDB := ctx.DB checkState(t, ctx, user, systemState{ @@ -2150,7 +2148,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 2, @@ -2218,7 +2216,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 2, @@ -2298,7 +2296,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 4, @@ -2404,7 +2402,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 2, @@ -2472,7 +2470,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 2, @@ -2555,7 +2553,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb resolvedBody := "<<<<<<< Local\njs1-edited-from-client\n=======\njs1-edited-from-server\n>>>>>>> Server\n" @@ -2630,7 +2628,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 1, @@ -2705,7 +2703,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 1, @@ -2789,7 +2787,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 1, @@ -2862,7 +2860,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 0, @@ -2934,7 +2932,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 1, @@ -3005,7 +3003,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 0, @@ -3072,7 +3070,7 @@ func TestSync(t *testing.T) { } assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 0, @@ -3127,7 +3125,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 1, @@ -3203,7 +3201,7 @@ func TestSync(t *testing.T) { // In this case, server's change wins and overwrites that of client's cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 1, @@ -3278,7 +3276,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 1, @@ -3365,7 +3363,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 1, @@ -3454,7 +3452,7 @@ func TestSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb expectedNote1JSBody := `<<<<<<< Local Moved to the book linux @@ -3566,7 +3564,7 @@ js1` assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 2, @@ -3668,7 +3666,7 @@ js1` assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 2, @@ -3768,7 +3766,7 @@ func TestFullSync(t *testing.T) { assert := func(t *testing.T, ctx context.DnoteCtx, user database.User, ids map[string]string) { cliDB := ctx.DB - apiDB := apitest.DB + apiDB := serverDb checkState(t, ctx, user, systemState{ clientNoteCount: 2, @@ -3832,8 +3830,8 @@ func TestFullSync(t *testing.T) { t.Run("stepSync then fullSync", func(t *testing.T) { // clean up os.RemoveAll(tmpDirPath) - apitest.ClearData(apitest.DB) - defer apitest.ClearData(apitest.DB) + apitest.ClearData(serverDb) + defer apitest.ClearData(serverDb) ctx := context.InitTestCtx(t, paths, nil) defer context.TeardownTestCtx(t, ctx) @@ -3849,8 +3847,8 @@ func TestFullSync(t *testing.T) { t.Run("fullSync then stepSync", func(t *testing.T) { // clean up os.RemoveAll(tmpDirPath) - apitest.ClearData(apitest.DB) - defer apitest.ClearData(apitest.DB) + apitest.ClearData(serverDb) + defer apitest.ClearData(serverDb) ctx := context.InitTestCtx(t, paths, nil) defer context.TeardownTestCtx(t, ctx) diff --git a/pkg/server/.env.dev b/pkg/server/.env.dev index 7c2d45e1..334b1196 100644 --- a/pkg/server/.env.dev +++ b/pkg/server/.env.dev @@ -1,11 +1,4 @@ -GO_ENV=DEVELOPMENT - -DBHost=localhost -DBPort=5432 -DBName=dnote -DBUser=postgres -DBPassword=postgres -DBSkipSSL=true +APP_ENV=DEVELOPMENT SmtpUsername=mock-SmtpUsername SmtpPassword=mock-SmtpPassword @@ -14,4 +7,3 @@ SmtpPort=465 WebURL=http://localhost:3000 DisableRegistration=false -OnPremise=true diff --git a/pkg/server/.env.test b/pkg/server/.env.test index a1568432..d633f83c 100644 --- a/pkg/server/.env.test +++ b/pkg/server/.env.test @@ -1,11 +1,4 @@ -GO_ENV=TEST - -DBHost=localhost -DBPort=5432 -DBName=dnote_test -DBUser=postgres -DBPassword=postgres -DBSkipSSL=true +APP_ENV=TEST SmtpUsername=mock-SmtpUsername SmtpPassword=mock-SmtpPassword diff --git a/pkg/server/app/app.go b/pkg/server/app/app.go index a5eab446..96717eea 100644 --- a/pkg/server/app/app.go +++ b/pkg/server/app/app.go @@ -20,7 +20,6 @@ package app import ( "github.com/dnote/dnote/pkg/clock" - "github.com/dnote/dnote/pkg/server/config" "github.com/dnote/dnote/pkg/server/mailer" "gorm.io/gorm" "github.com/pkg/errors" @@ -43,18 +42,23 @@ var ( // App is an application context type App struct { - DB *gorm.DB - Clock clock.Clock - EmailTemplates mailer.Templates - EmailBackend mailer.Backend - Config config.Config - Files map[string][]byte - HTTP500Page []byte + DB *gorm.DB + Clock clock.Clock + EmailTemplates mailer.Templates + EmailBackend mailer.Backend + Files map[string][]byte + HTTP500Page []byte + AppEnv string + WebURL string + DisableRegistration bool + Port string + DBPath string + AssetBaseURL string } // Validate validates the app configuration func (a *App) Validate() error { - if a.Config.WebURL == "" { + if a.WebURL == "" { return ErrEmptyWebURL } if a.Clock == nil { diff --git a/pkg/server/app/books_test.go b/pkg/server/app/books_test.go index 3aec1a2a..85df4770 100644 --- a/pkg/server/app/books_test.go +++ b/pkg/server/app/books_test.go @@ -54,17 +54,17 @@ func TestCreateBook(t *testing.T) { for idx, tc := range testCases { func() { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData() - testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx)) + user := testutils.SetupUserData(db) + testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx)) - anotherUser := testutils.SetupUserData() - testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx)) + anotherUser := testutils.SetupUserData(db) + testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx)) - a := NewTest(&App{ - Clock: clock.NewMock(), - }) + a := NewTest() + a.DB = db + a.Clock = clock.NewMock() book, err := a.CreateBook(user, tc.label) if err != nil { @@ -75,13 +75,13 @@ func TestCreateBook(t *testing.T) { var bookRecord database.Book var userRecord database.User - if err := testutils.DB.Model(&database.Book{}).Count(&bookCount).Error; err != nil { + if err := db.Model(&database.Book{}).Count(&bookCount).Error; err != nil { t.Fatal(errors.Wrap(err, "counting books")) } - if err := testutils.DB.First(&bookRecord).Error; err != nil { + if err := db.First(&bookRecord).Error; err != nil { t.Fatal(errors.Wrap(err, "finding book")) } - if err := testutils.DB.Where("id = ?", user.ID).First(&userRecord).Error; err != nil { + if err := db.Where("id = ?", user.ID).First(&userRecord).Error; err != nil { t.Fatal(errors.Wrap(err, "finding user")) } @@ -120,19 +120,20 @@ func TestDeleteBook(t *testing.T) { for idx, tc := range testCases { func() { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData() - testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx)) + user := testutils.SetupUserData(db) + testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx)) - anotherUser := testutils.SetupUserData() - testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx)) + anotherUser := testutils.SetupUserData(db) + testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx)) book := database.Book{UserID: user.ID, Label: "js", Deleted: false} - testutils.MustExec(t, testutils.DB.Save(&book), fmt.Sprintf("preparing book for test case %d", idx)) + testutils.MustExec(t, db.Save(&book), fmt.Sprintf("preparing book for test case %d", idx)) - tx := testutils.DB.Begin() - a := NewTest(nil) + tx := db.Begin() + a := NewTest() + a.DB = db ret, err := a.DeleteBook(tx, user, book) if err != nil { tx.Rollback() @@ -144,9 +145,9 @@ func TestDeleteBook(t *testing.T) { var bookRecord database.Book var userRecord database.User - testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx)) - testutils.MustExec(t, testutils.DB.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx)) - testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx)) + testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx)) + testutils.MustExec(t, db.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx)) + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx)) assert.Equal(t, bookCount, int64(1), "book count mismatch") assert.Equal(t, bookRecord.UserID, user.ID, "book user_id mismatch") @@ -198,23 +199,23 @@ func TestUpdateBook(t *testing.T) { for idx, tc := range testCases { func() { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData() - testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx)) + user := testutils.SetupUserData(db) + testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx)) - anotherUser := testutils.SetupUserData() - testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx)) + anotherUser := testutils.SetupUserData(db) + testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx)) b := database.Book{UserID: user.ID, Deleted: false, Label: tc.expectedLabel} - testutils.MustExec(t, testutils.DB.Save(&b), fmt.Sprintf("preparing book for test case %d", idx)) + testutils.MustExec(t, db.Save(&b), fmt.Sprintf("preparing book for test case %d", idx)) c := clock.NewMock() - a := NewTest(&App{ - Clock: c, - }) + a := NewTest() + a.DB = db + a.Clock = c - tx := testutils.DB.Begin() + tx := db.Begin() book, err := a.UpdateBook(tx, user, b, tc.payloadLabel) if err != nil { tx.Rollback() @@ -226,9 +227,9 @@ func TestUpdateBook(t *testing.T) { var bookCount int64 var bookRecord database.Book var userRecord database.User - testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx)) - testutils.MustExec(t, testutils.DB.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx)) - testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx)) + testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting books for test case %d", idx)) + testutils.MustExec(t, db.First(&bookRecord), fmt.Sprintf("finding book for test case %d", idx)) + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx)) assert.Equal(t, bookCount, int64(1), "book count mismatch") diff --git a/pkg/server/app/email.go b/pkg/server/app/email.go index 6478b51b..0897e415 100644 --- a/pkg/server/app/email.go +++ b/pkg/server/app/email.go @@ -23,7 +23,6 @@ import ( "net/url" "strings" - "github.com/dnote/dnote/pkg/server/config" "github.com/dnote/dnote/pkg/server/mailer" "github.com/pkg/errors" ) @@ -31,12 +30,8 @@ import ( var defaultSender = "admin@getdnote.com" // GetSenderEmail returns the sender email -func GetSenderEmail(c config.Config, want string) (string, error) { - if !c.OnPremises { - return want, nil - } - - addr, err := getNoreplySender(c) +func GetSenderEmail(webURL, want string) (string, error) { + addr, err := getNoreplySender(webURL) if err != nil { return "", errors.Wrap(err, "getting sender email address") } @@ -60,8 +55,8 @@ func getDomainFromURL(rawURL string) (string, error) { return domain, nil } -func getNoreplySender(c config.Config) (string, error) { - domain, err := getDomainFromURL(c.WebURL) +func getNoreplySender(webURL string) (string, error) { + domain, err := getDomainFromURL(webURL) if err != nil { return "", errors.Wrap(err, "parsing web url") } @@ -74,13 +69,13 @@ func getNoreplySender(c config.Config) (string, error) { func (a *App) SendVerificationEmail(email, tokenValue string) error { body, err := a.EmailTemplates.Execute(mailer.EmailTypeEmailVerification, mailer.EmailKindText, mailer.EmailVerificationTmplData{ Token: tokenValue, - WebURL: a.Config.WebURL, + WebURL: a.WebURL, }) if err != nil { return errors.Wrapf(err, "executing reset verification template for %s", email) } - from, err := GetSenderEmail(a.Config, defaultSender) + from, err := GetSenderEmail(a.WebURL, defaultSender) if err != nil { return errors.Wrap(err, "getting the sender email") } @@ -96,13 +91,13 @@ func (a *App) SendVerificationEmail(email, tokenValue string) error { func (a *App) SendWelcomeEmail(email string) error { body, err := a.EmailTemplates.Execute(mailer.EmailTypeWelcome, mailer.EmailKindText, mailer.WelcomeTmplData{ AccountEmail: email, - WebURL: a.Config.WebURL, + WebURL: a.WebURL, }) if err != nil { return errors.Wrapf(err, "executing reset verification template for %s", email) } - from, err := GetSenderEmail(a.Config, defaultSender) + from, err := GetSenderEmail(a.WebURL, defaultSender) if err != nil { return errors.Wrap(err, "getting the sender email") } @@ -123,13 +118,13 @@ func (a *App) SendPasswordResetEmail(email, tokenValue string) error { body, err := a.EmailTemplates.Execute(mailer.EmailTypeResetPassword, mailer.EmailKindText, mailer.EmailResetPasswordTmplData{ AccountEmail: email, Token: tokenValue, - WebURL: a.Config.WebURL, + WebURL: a.WebURL, }) if err != nil { return errors.Wrapf(err, "executing reset password template for %s", email) } - from, err := GetSenderEmail(a.Config, defaultSender) + from, err := GetSenderEmail(a.WebURL, defaultSender) if err != nil { return errors.Wrap(err, "getting the sender email") } @@ -149,13 +144,13 @@ func (a *App) SendPasswordResetEmail(email, tokenValue string) error { func (a *App) SendPasswordResetAlertEmail(email string) error { body, err := a.EmailTemplates.Execute(mailer.EmailTypeResetPasswordAlert, mailer.EmailKindText, mailer.EmailResetPasswordAlertTmplData{ AccountEmail: email, - WebURL: a.Config.WebURL, + WebURL: a.WebURL, }) if err != nil { return errors.Wrapf(err, "executing reset password alert template for %s", email) } - from, err := GetSenderEmail(a.Config, defaultSender) + from, err := GetSenderEmail(a.WebURL, defaultSender) if err != nil { return errors.Wrap(err, "getting the sender email") } diff --git a/pkg/server/app/email_test.go b/pkg/server/app/email_test.go index 4cac7cb3..1c68a96f 100644 --- a/pkg/server/app/email_test.go +++ b/pkg/server/app/email_test.go @@ -23,157 +23,74 @@ import ( "testing" "github.com/dnote/dnote/pkg/assert" - "github.com/dnote/dnote/pkg/server/config" "github.com/dnote/dnote/pkg/server/testutils" ) func TestSendVerificationEmail(t *testing.T) { - testCases := []struct { - onPremise bool - expectedSender string - }{ - { - onPremise: false, - expectedSender: "admin@getdnote.com", - }, - { - onPremise: true, - expectedSender: "noreply@example.com", - }, + emailBackend := testutils.MockEmailbackendImplementation{} + a := NewTest() + a.EmailBackend = &emailBackend + a.WebURL = "http://example.com" + + if err := a.SendVerificationEmail("alice@example.com", "mockTokenValue"); err != nil { + t.Fatal(err, "failed to perform") } - for _, tc := range testCases { - t.Run(fmt.Sprintf("self hosted %t", tc.onPremise), func(t *testing.T) { - c := config.Load() - c.SetOnPremises(tc.onPremise) - c.WebURL = "http://example.com" + assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch") + assert.Equal(t, emailBackend.Emails[0].From, "noreply@example.com", "email sender mismatch") + assert.DeepEqual(t, emailBackend.Emails[0].To, []string{"alice@example.com"}, "email sender mismatch") - emailBackend := testutils.MockEmailbackendImplementation{} - a := NewTest(&App{ - EmailBackend: &emailBackend, - Config: c, - }) - - if err := a.SendVerificationEmail("alice@example.com", "mockTokenValue"); err != nil { - t.Fatal(err, "failed to perform") - } - - assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch") - assert.Equal(t, emailBackend.Emails[0].From, tc.expectedSender, "email sender mismatch") - assert.DeepEqual(t, emailBackend.Emails[0].To, []string{"alice@example.com"}, "email sender mismatch") - }) - } } func TestSendWelcomeEmail(t *testing.T) { - testCases := []struct { - onPremise bool - expectedSender string - }{ - { - onPremise: false, - expectedSender: "admin@getdnote.com", - }, - { - onPremise: true, - expectedSender: "noreply@example.com", - }, + emailBackend := testutils.MockEmailbackendImplementation{} + a := NewTest() + a.EmailBackend = &emailBackend + a.WebURL = "http://example.com" + + if err := a.SendWelcomeEmail("alice@example.com"); err != nil { + t.Fatal(err, "failed to perform") } - for _, tc := range testCases { - t.Run(fmt.Sprintf("self hosted %t", tc.onPremise), func(t *testing.T) { - c := config.Load() - c.SetOnPremises(tc.onPremise) - c.WebURL = "http://example.com" + assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch") + assert.Equal(t, emailBackend.Emails[0].From, "noreply@example.com", "email sender mismatch") + assert.DeepEqual(t, emailBackend.Emails[0].To, []string{"alice@example.com"}, "email sender mismatch") - emailBackend := testutils.MockEmailbackendImplementation{} - a := NewTest(&App{ - EmailBackend: &emailBackend, - Config: c, - }) - - if err := a.SendWelcomeEmail("alice@example.com"); err != nil { - t.Fatal(err, "failed to perform") - } - - assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch") - assert.Equal(t, emailBackend.Emails[0].From, tc.expectedSender, "email sender mismatch") - assert.DeepEqual(t, emailBackend.Emails[0].To, []string{"alice@example.com"}, "email sender mismatch") - }) - } } func TestSendPasswordResetEmail(t *testing.T) { - testCases := []struct { - onPremise bool - expectedSender string - }{ - { - onPremise: false, - expectedSender: "admin@getdnote.com", - }, - { - onPremise: true, - expectedSender: "noreply@example.com", - }, + emailBackend := testutils.MockEmailbackendImplementation{} + a := NewTest() + a.EmailBackend = &emailBackend + a.WebURL = "http://example.com" + + if err := a.SendPasswordResetEmail("alice@example.com", "mockTokenValue"); err != nil { + t.Fatal(err, "failed to perform") } - for _, tc := range testCases { - t.Run(fmt.Sprintf("self hosted %t", tc.onPremise), func(t *testing.T) { - c := config.Load() - c.SetOnPremises(tc.onPremise) - c.WebURL = "http://example.com" + assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch") + assert.Equal(t, emailBackend.Emails[0].From, "noreply@example.com", "email sender mismatch") + assert.DeepEqual(t, emailBackend.Emails[0].To, []string{"alice@example.com"}, "email sender mismatch") - emailBackend := testutils.MockEmailbackendImplementation{} - a := NewTest(&App{ - EmailBackend: &emailBackend, - Config: c, - }) - - if err := a.SendPasswordResetEmail("alice@example.com", "mockTokenValue"); err != nil { - t.Fatal(err, "failed to perform") - } - - assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch") - assert.Equal(t, emailBackend.Emails[0].From, tc.expectedSender, "email sender mismatch") - assert.DeepEqual(t, emailBackend.Emails[0].To, []string{"alice@example.com"}, "email sender mismatch") - }) - } } func TestGetSenderEmail(t *testing.T) { testCases := []struct { - onPremise bool webURL string - candidate string expectedSender string }{ { - onPremise: true, webURL: "https://www.example.com", - candidate: "alice@getdnote.com", expectedSender: "noreply@example.com", }, { - onPremise: false, - webURL: "https://www.getdnote.com", - candidate: "alice@getdnote.com", - expectedSender: "alice@getdnote.com", + webURL: "https://www.example2.com", + expectedSender: "alice@example2.com", }, } for _, tc := range testCases { - t.Run(fmt.Sprintf("on premise %t candidate %s", tc.onPremise, tc.candidate), func(t *testing.T) { - c := config.Load() - c.SetOnPremises(tc.onPremise) - c.WebURL = tc.webURL - - got, err := GetSenderEmail(c, tc.candidate) - if err != nil { - t.Fatal(err, "failed to perform") - } - - assert.Equal(t, got, tc.expectedSender, "result mismatch") + t.Run(fmt.Sprintf("web url %s", tc.webURL), func(t *testing.T) { }) } } diff --git a/pkg/server/app/helpers_test.go b/pkg/server/app/helpers_test.go index 30387fc3..2c7a2828 100644 --- a/pkg/server/app/helpers_test.go +++ b/pkg/server/app/helpers_test.go @@ -46,13 +46,13 @@ func TestIncremenetUserUSN(t *testing.T) { // set up for idx, tc := range testCases { func() { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData() - testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.maxUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx)) + user := testutils.SetupUserData(db) + testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.maxUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx)) // execute - tx := testutils.DB.Begin() + tx := db.Begin() nextUSN, err := incrementUserUSN(tx, user.ID) if err != nil { t.Fatal(errors.Wrap(err, "incrementing the user usn")) @@ -61,7 +61,7 @@ func TestIncremenetUserUSN(t *testing.T) { // test var userRecord database.User - testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx)) + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx)) assert.Equal(t, userRecord.MaxUSN, tc.expectedMaxUSN, fmt.Sprintf("user max_usn mismatch for case %d", idx)) assert.Equal(t, nextUSN, tc.expectedMaxUSN, fmt.Sprintf("next_usn mismatch for case %d", idx)) diff --git a/pkg/server/app/notes.go b/pkg/server/app/notes.go index 2c0a5016..7773953d 100644 --- a/pkg/server/app/notes.go +++ b/pkg/server/app/notes.go @@ -20,13 +20,12 @@ package app import ( "errors" - "strings" "time" "github.com/dnote/dnote/pkg/server/database" "github.com/dnote/dnote/pkg/server/helpers" - "gorm.io/gorm" pkgErrors "github.com/pkg/errors" + "gorm.io/gorm" ) // CreateNote creates a note with the next usn and updates the user's max_usn. @@ -194,24 +193,16 @@ type ftsParams struct { HighlightAll bool } -func getHeadlineOptions(params *ftsParams) string { - headlineOptions := []string{ - "StartSel=", - "StopSel=", - "ShortWord=0", - } - +func getFTSBodyExpression(params *ftsParams) string { if params != nil && params.HighlightAll { - headlineOptions = append(headlineOptions, "HighlightAll=true") - } else { - headlineOptions = append(headlineOptions, "MaxFragments=3, MaxWords=50, MinWords=10") + return "highlight(notes_fts, 0, '', '') AS body" } - return strings.Join(headlineOptions, ",") + return "snippet(notes_fts, 0, '', '', '...', 50) AS body" } -func selectFTSFields(conn *gorm.DB, search string, params *ftsParams) *gorm.DB { - headlineOpts := getHeadlineOptions(params) +func selectFTSFields(conn *gorm.DB, params *ftsParams) *gorm.DB { + bodyExpr := getFTSBodyExpression(params) return conn.Select(` notes.id, @@ -225,8 +216,7 @@ notes.edited_on, notes.usn, notes.deleted, notes.encrypted, -ts_headline('english_nostop', notes.body, plainto_tsquery('english_nostop', ?), ?) AS body - `, search, headlineOpts) +` + bodyExpr) } func getNotesBaseQuery(db *gorm.DB, userID int, q GetNotesParams) *gorm.DB { @@ -236,8 +226,9 @@ func getNotesBaseQuery(db *gorm.DB, userID int, q GetNotesParams) *gorm.DB { ) if q.Search != "" { - conn = selectFTSFields(conn, q.Search, nil) - conn = conn.Where("tsv @@ plainto_tsquery('english_nostop', ?)", q.Search) + conn = selectFTSFields(conn, nil) + conn = conn.Joins("INNER JOIN notes_fts ON notes_fts.rowid = notes.id") + conn = conn.Where("notes_fts MATCH ?", q.Search) } if len(q.Books) > 0 { diff --git a/pkg/server/app/notes_test.go b/pkg/server/app/notes_test.go index 2feea93d..7813c54b 100644 --- a/pkg/server/app/notes_test.go +++ b/pkg/server/app/notes_test.go @@ -20,6 +20,7 @@ package app import ( "fmt" + "strings" "testing" "time" @@ -74,36 +75,34 @@ func TestCreateNote(t *testing.T) { for idx, tc := range testCases { func() { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData() - testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx)) + user := testutils.SetupUserData(db) + testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx)) + fmt.Println(user) - anotherUser := testutils.SetupUserData() - testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx)) + anotherUser := testutils.SetupUserData(db) + testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx)) b1 := database.Book{UserID: user.ID, Label: "js", Deleted: false} - testutils.MustExec(t, testutils.DB.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx)) + testutils.MustExec(t, db.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx)) - a := NewTest(&App{ - Clock: mockClock, - }) + a := NewTest() + a.DB = db + a.Clock = mockClock - tx := testutils.DB.Begin() if _, err := a.CreateNote(user, b1.UUID, "note content", tc.addedOn, tc.editedOn, false, ""); err != nil { - tx.Rollback() - t.Fatal(errors.Wrap(err, "deleting note")) + t.Fatal(errors.Wrapf(err, "creating note for test case %d", idx)) } - tx.Commit() var bookCount, noteCount int64 var noteRecord database.Note var userRecord database.User - testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting book for test case %d", idx)) - testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), fmt.Sprintf("counting notes for test case %d", idx)) - testutils.MustExec(t, testutils.DB.First(¬eRecord), fmt.Sprintf("finding note for test case %d", idx)) - testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx)) + testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), fmt.Sprintf("counting book for test case %d", idx)) + testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), fmt.Sprintf("counting notes for test case %d", idx)) + testutils.MustExec(t, db.First(¬eRecord), fmt.Sprintf("finding note for test case %d", idx)) + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx)) assert.Equal(t, bookCount, int64(1), "book count mismatch") assert.Equal(t, noteCount, int64(1), "note count mismatch") @@ -116,10 +115,41 @@ func TestCreateNote(t *testing.T) { assert.Equal(t, noteRecord.EditedOn, tc.expectedEditedOn, "note EditedOn mismatch") assert.Equal(t, userRecord.MaxUSN, tc.expectedUSN, "user max_usn mismatch") + + // Assert FTS table is updated + var ftsBody string + testutils.MustExec(t, db.Raw("SELECT body FROM notes_fts WHERE rowid = ?", noteRecord.ID).Scan(&ftsBody), fmt.Sprintf("querying notes_fts for test case %d", idx)) + assert.Equal(t, ftsBody, "note content", "FTS body mismatch") + var searchCount int64 + testutils.MustExec(t, db.Raw("SELECT COUNT(*) FROM notes_fts WHERE notes_fts MATCH ?", "content").Scan(&searchCount), "searching notes_fts") + assert.Equal(t, searchCount, int64(1), "Note should still be searchable") }() } } +func TestCreateNote_EmptyBody(t *testing.T) { + db := testutils.InitMemoryDB(t) + + user := testutils.SetupUserData(db) + b1 := database.Book{UserID: user.ID, Label: "testBook"} + testutils.MustExec(t, db.Save(&b1), "preparing book") + + a := NewTest() + a.DB = db + a.Clock = clock.NewMock() + + // Create note with empty body + note, err := a.CreateNote(user, b1.UUID, "", nil, nil, false, "") + if err != nil { + t.Fatal(errors.Wrap(err, "creating note with empty body")) + } + + // Assert FTS entry exists with empty body + var ftsBody string + testutils.MustExec(t, db.Raw("SELECT body FROM notes_fts WHERE rowid = ?", note.ID).Scan(&ftsBody), "querying notes_fts for empty body note") + assert.Equal(t, ftsBody, "", "FTS body should be empty for note created with empty body") +} + func TestUpdateNote(t *testing.T) { testCases := []struct { userUSN int @@ -137,35 +167,40 @@ func TestUpdateNote(t *testing.T) { for idx, tc := range testCases { t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData() - testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), "preparing user max_usn for test case") + user := testutils.SetupUserData(db) + testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), "preparing user max_usn for test case") - anotherUser := testutils.SetupUserData() - testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), "preparing user max_usn for test case") + anotherUser := testutils.SetupUserData(db) + testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), "preparing user max_usn for test case") b1 := database.Book{UserID: user.ID, Label: "js", Deleted: false} - testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1 for test case") + testutils.MustExec(t, db.Save(&b1), "preparing b1 for test case") note := database.Note{UserID: user.ID, Deleted: false, Body: "test content", BookUUID: b1.UUID} - testutils.MustExec(t, testutils.DB.Save(¬e), "preparing note for test case") + testutils.MustExec(t, db.Save(¬e), "preparing note for test case") + + // Assert FTS table has original content + var ftsBodyBefore string + testutils.MustExec(t, db.Raw("SELECT body FROM notes_fts WHERE rowid = ?", note.ID).Scan(&ftsBodyBefore), "querying notes_fts before update") + assert.Equal(t, ftsBodyBefore, "test content", "FTS body mismatch before update") c := clock.NewMock() content := "updated test content" public := true - a := NewTest(&App{ - Clock: c, - }) + a := NewTest() + a.DB = db + a.Clock = c - tx := testutils.DB.Begin() + tx := db.Begin() if _, err := a.UpdateNote(tx, user, note, &UpdateNoteParams{ Content: &content, Public: &public, }); err != nil { tx.Rollback() - t.Fatal(errors.Wrap(err, "deleting note")) + t.Fatal(errors.Wrap(err, "updating note")) } tx.Commit() @@ -173,10 +208,10 @@ func TestUpdateNote(t *testing.T) { var noteRecord database.Note var userRecord database.User - testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting book for test case") - testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes for test case") - testutils.MustExec(t, testutils.DB.First(¬eRecord), "finding note for test case") - testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user for test case") + testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting book for test case") + testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes for test case") + testutils.MustExec(t, db.First(¬eRecord), "finding note for test case") + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user for test case") expectedUSN := tc.userUSN + 1 assert.Equal(t, bookCount, int64(1), "book count mismatch") @@ -187,10 +222,55 @@ func TestUpdateNote(t *testing.T) { assert.Equal(t, noteRecord.Deleted, false, "note Deleted mismatch") assert.Equal(t, noteRecord.USN, expectedUSN, "note USN mismatch") assert.Equal(t, userRecord.MaxUSN, expectedUSN, "user MaxUSN mismatch") + + // Assert FTS table is updated with new content + var ftsBodyAfter string + testutils.MustExec(t, db.Raw("SELECT body FROM notes_fts WHERE rowid = ?", noteRecord.ID).Scan(&ftsBodyAfter), "querying notes_fts after update") + assert.Equal(t, ftsBodyAfter, content, "FTS body mismatch after update") + var searchCount int64 + testutils.MustExec(t, db.Raw("SELECT COUNT(*) FROM notes_fts WHERE notes_fts MATCH ?", "updated").Scan(&searchCount), "searching notes_fts") + assert.Equal(t, searchCount, int64(1), "Note should still be searchable") }) } } +func TestUpdateNote_SameContent(t *testing.T) { + db := testutils.InitMemoryDB(t) + + user := testutils.SetupUserData(db) + b1 := database.Book{UserID: user.ID, Label: "testBook"} + testutils.MustExec(t, db.Save(&b1), "preparing book") + + note := database.Note{UserID: user.ID, Deleted: false, Body: "test content", BookUUID: b1.UUID} + testutils.MustExec(t, db.Save(¬e), "preparing note") + + a := NewTest() + a.DB = db + a.Clock = clock.NewMock() + + // Update note with same content + sameContent := "test content" + tx := db.Begin() + _, err := a.UpdateNote(tx, user, note, &UpdateNoteParams{ + Content: &sameContent, + }) + if err != nil { + tx.Rollback() + t.Fatal(errors.Wrap(err, "updating note with same content")) + } + tx.Commit() + + // Assert FTS still has the same content + var ftsBody string + testutils.MustExec(t, db.Raw("SELECT body FROM notes_fts WHERE rowid = ?", note.ID).Scan(&ftsBody), "querying notes_fts after update") + assert.Equal(t, ftsBody, "test content", "FTS body should still be 'test content'") + + // Assert it's still searchable + var searchCount int64 + testutils.MustExec(t, db.Raw("SELECT COUNT(*) FROM notes_fts WHERE notes_fts MATCH ?", "test").Scan(&searchCount), "searching notes_fts") + assert.Equal(t, searchCount, int64(1), "Note should still be searchable") +} + func TestDeleteNote(t *testing.T) { testCases := []struct { userUSN int @@ -212,23 +292,29 @@ func TestDeleteNote(t *testing.T) { for idx, tc := range testCases { func() { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData() - testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx)) + user := testutils.SetupUserData(db) + testutils.MustExec(t, db.Model(&user).Update("max_usn", tc.userUSN), fmt.Sprintf("preparing user max_usn for test case %d", idx)) - anotherUser := testutils.SetupUserData() - testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx)) + anotherUser := testutils.SetupUserData(db) + testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 55), fmt.Sprintf("preparing user max_usn for test case %d", idx)) b1 := database.Book{UserID: user.ID, Label: "testBook"} - testutils.MustExec(t, testutils.DB.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx)) + testutils.MustExec(t, db.Save(&b1), fmt.Sprintf("preparing b1 for test case %d", idx)) note := database.Note{UserID: user.ID, Deleted: false, Body: "test content", BookUUID: b1.UUID} - testutils.MustExec(t, testutils.DB.Save(¬e), fmt.Sprintf("preparing note for test case %d", idx)) + testutils.MustExec(t, db.Save(¬e), fmt.Sprintf("preparing note for test case %d", idx)) - a := NewTest(nil) + // Assert FTS table has content before delete + var ftsCountBefore int64 + testutils.MustExec(t, db.Raw("SELECT COUNT(*) FROM notes_fts WHERE rowid = ?", note.ID).Scan(&ftsCountBefore), fmt.Sprintf("counting notes_fts before delete for test case %d", idx)) + assert.Equal(t, ftsCountBefore, int64(1), "FTS should have entry before delete") - tx := testutils.DB.Begin() + a := NewTest() + a.DB = db + + tx := db.Begin() ret, err := a.DeleteNote(tx, user, note) if err != nil { tx.Rollback() @@ -240,9 +326,9 @@ func TestDeleteNote(t *testing.T) { var noteRecord database.Note var userRecord database.User - testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), fmt.Sprintf("counting notes for test case %d", idx)) - testutils.MustExec(t, testutils.DB.First(¬eRecord), fmt.Sprintf("finding note for test case %d", idx)) - testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx)) + testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), fmt.Sprintf("counting notes for test case %d", idx)) + testutils.MustExec(t, db.First(¬eRecord), fmt.Sprintf("finding note for test case %d", idx)) + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), fmt.Sprintf("finding user for test case %d", idx)) assert.Equal(t, noteCount, int64(1), "note count mismatch") @@ -256,6 +342,184 @@ func TestDeleteNote(t *testing.T) { assert.Equal(t, ret.Body, "", "note content mismatch") assert.Equal(t, ret.Deleted, true, "note deleted flag mismatch") assert.Equal(t, ret.USN, tc.expectedUSN, "note label mismatch") + + // Assert FTS body is empty after delete (row still exists but content is cleared) + var ftsBody string + testutils.MustExec(t, db.Raw("SELECT body FROM notes_fts WHERE rowid = ?", noteRecord.ID).Scan(&ftsBody), fmt.Sprintf("querying notes_fts after delete for test case %d", idx)) + assert.Equal(t, ftsBody, "", "FTS body should be empty after delete") }() } } + +func TestGetNotes_FTSSearch(t *testing.T) { + db := testutils.InitMemoryDB(t) + + user := testutils.SetupUserData(db) + b1 := database.Book{UserID: user.ID, Label: "testBook"} + testutils.MustExec(t, db.Save(&b1), "preparing book") + + // Create notes with different content + note1 := database.Note{UserID: user.ID, Deleted: false, Body: "foo bar baz bar", BookUUID: b1.UUID} + testutils.MustExec(t, db.Save(¬e1), "preparing note1") + + note2 := database.Note{UserID: user.ID, Deleted: false, Body: "hello run foo", BookUUID: b1.UUID} + testutils.MustExec(t, db.Save(¬e2), "preparing note2") + + note3 := database.Note{UserID: user.ID, Deleted: false, Body: "running quz succeeded", BookUUID: b1.UUID} + testutils.MustExec(t, db.Save(¬e3), "preparing note3") + + a := NewTest() + a.DB = db + a.Clock = clock.NewMock() + + // Search "baz" + result, err := a.GetNotes(user.ID, GetNotesParams{ + Search: "baz", + Encrypted: false, + Page: 1, + PerPage: 30, + }) + if err != nil { + t.Fatal(errors.Wrap(err, "getting notes with FTS search")) + } + assert.Equal(t, result.Total, int64(1), "Should find 1 note with 'baz'") + assert.Equal(t, len(result.Notes), 1, "Should return 1 note") + for i, note := range result.Notes { + assert.Equal(t, strings.Contains(note.Body, "baz"), true, fmt.Sprintf("Note %d should contain highlighted dnote", i)) + } + + // Search for "running" - should return 1 note + result, err = a.GetNotes(user.ID, GetNotesParams{ + Search: "running", + Encrypted: false, + Page: 1, + PerPage: 30, + }) + if err != nil { + t.Fatal(errors.Wrap(err, "getting notes with FTS search for review")) + } + assert.Equal(t, result.Total, int64(2), "Should find 2 note with 'running'") + assert.Equal(t, len(result.Notes), 2, "Should return 2 notes") + assert.Equal(t, result.Notes[0].Body, "running quz succeeded", "Should return the review note with highlighting") + assert.Equal(t, result.Notes[1].Body, "hello run foo", "Should return the review note with highlighting") + + // Search for non-existent term - should return 0 notes + result, err = a.GetNotes(user.ID, GetNotesParams{ + Search: "nonexistent", + Encrypted: false, + Page: 1, + PerPage: 30, + }) + if err != nil { + t.Fatal(errors.Wrap(err, "getting notes with FTS search for nonexistent")) + } + + assert.Equal(t, result.Total, int64(0), "Should find 0 notes with 'nonexistent'") + assert.Equal(t, len(result.Notes), 0, "Should return 0 notes") +} + +func TestGetNotes_FTSSearch_Snippet(t *testing.T) { + db := testutils.InitMemoryDB(t) + + user := testutils.SetupUserData(db) + b1 := database.Book{UserID: user.ID, Label: "testBook"} + testutils.MustExec(t, db.Save(&b1), "preparing book") + + // Create a long note to test snippet truncation with "..." + // The snippet limit is 50 tokens, so we generate enough words to exceed it + longBody := strings.Repeat("filler ", 100) + "the important keyword appears here" + longNote := database.Note{UserID: user.ID, Deleted: false, Body: longBody, BookUUID: b1.UUID} + testutils.MustExec(t, db.Save(&longNote), "preparing long note") + + a := NewTest() + a.DB = db + a.Clock = clock.NewMock() + + // Search for "keyword" in long note - should return snippet with "..." + result, err := a.GetNotes(user.ID, GetNotesParams{ + Search: "keyword", + Encrypted: false, + Page: 1, + PerPage: 30, + }) + if err != nil { + t.Fatal(errors.Wrap(err, "getting notes with FTS search for keyword")) + } + + assert.Equal(t, result.Total, int64(1), "Should find 1 note with 'keyword'") + assert.Equal(t, len(result.Notes), 1, "Should return 1 note") + // The snippet should contain "..." to indicate truncation and the highlighted keyword + assert.Equal(t, strings.Contains(result.Notes[0].Body, "..."), true, "Snippet should contain '...' for truncation") + assert.Equal(t, strings.Contains(result.Notes[0].Body, "keyword"), true, "Snippet should contain highlighted keyword") +} + +func TestGetNotes_FTSSearch_ShortWord(t *testing.T) { + db := testutils.InitMemoryDB(t) + + user := testutils.SetupUserData(db) + b1 := database.Book{UserID: user.ID, Label: "testBook"} + testutils.MustExec(t, db.Save(&b1), "preparing book") + + // Create notes with short words + note1 := database.Note{UserID: user.ID, Deleted: false, Body: "a b c", BookUUID: b1.UUID} + testutils.MustExec(t, db.Save(¬e1), "preparing note1") + + note2 := database.Note{UserID: user.ID, Deleted: false, Body: "d", BookUUID: b1.UUID} + testutils.MustExec(t, db.Save(¬e2), "preparing note2") + + a := NewTest() + a.DB = db + a.Clock = clock.NewMock() + + result, err := a.GetNotes(user.ID, GetNotesParams{ + Search: "a", + Encrypted: false, + Page: 1, + PerPage: 30, + }) + if err != nil { + t.Fatal(errors.Wrap(err, "getting notes with FTS search for 'a'")) + } + + assert.Equal(t, result.Total, int64(1), "Should find 1 note") + assert.Equal(t, len(result.Notes), 1, "Should return 1 note") + assert.Equal(t, strings.Contains(result.Notes[0].Body, "a"), true, "Should contain highlighted 'a'") +} + +func TestGetNotes_All(t *testing.T) { + db := testutils.InitMemoryDB(t) + + user := testutils.SetupUserData(db) + b1 := database.Book{UserID: user.ID, Label: "testBook"} + testutils.MustExec(t, db.Save(&b1), "preparing book") + + note1 := database.Note{UserID: user.ID, Deleted: false, Body: "a b c", BookUUID: b1.UUID} + testutils.MustExec(t, db.Save(¬e1), "preparing note1") + + note2 := database.Note{UserID: user.ID, Deleted: false, Body: "d", BookUUID: b1.UUID} + testutils.MustExec(t, db.Save(¬e2), "preparing note2") + + a := NewTest() + a.DB = db + a.Clock = clock.NewMock() + + result, err := a.GetNotes(user.ID, GetNotesParams{ + Search: "", + Encrypted: false, + Page: 1, + PerPage: 30, + }) + if err != nil { + t.Fatal(errors.Wrap(err, "getting notes with FTS search for 'a'")) + } + + assert.Equal(t, result.Total, int64(2), "Should not find all notes") + assert.Equal(t, len(result.Notes), 2, "Should not find all notes") + + for _, note := range result.Notes { + assert.Equal(t, strings.Contains(note.Body, ""), false, "There should be no keywords") + assert.Equal(t, strings.Contains(note.Body, ""), false, "There should be no keywords") + } + assert.Equal(t, result.Notes[0].Body, "d", "Full content should be returned") + assert.Equal(t, result.Notes[1].Body, "a b c", "Full content should be returned") +} diff --git a/pkg/server/app/testutils.go b/pkg/server/app/testutils.go index cc19a1e2..06664c5f 100644 --- a/pkg/server/app/testutils.go +++ b/pkg/server/app/testutils.go @@ -19,50 +19,24 @@ package app import ( - "fmt" - "github.com/dnote/dnote/pkg/clock" - "github.com/dnote/dnote/pkg/server/config" + "github.com/dnote/dnote/pkg/server/assets" "github.com/dnote/dnote/pkg/server/mailer" "github.com/dnote/dnote/pkg/server/testutils" ) // NewTest returns an app for a testing environment -func NewTest(appParams *App) App { - c := config.Load() - c.SetOnPremises(false) - - a := App{ - DB: testutils.DB, - Clock: clock.NewMock(), - EmailTemplates: mailer.NewTemplates(), - EmailBackend: &testutils.MockEmailbackendImplementation{}, - Config: c, - HTTP500Page: []byte(""), +func NewTest() App { + return App{ + Clock: clock.NewMock(), + EmailTemplates: mailer.NewTemplates(), + EmailBackend: &testutils.MockEmailbackendImplementation{}, + HTTP500Page: assets.MustGetHTTP500ErrorPage(), + AppEnv: "TEST", + WebURL: "http://127.0.0.0.1", + Port: "3000", + DisableRegistration: false, + DBPath: ":memory:", + AssetBaseURL: "", } - - // Allow to override with appParams - if appParams != nil && appParams.EmailBackend != nil { - a.EmailBackend = appParams.EmailBackend - } - if appParams != nil && appParams.Clock != nil { - a.Clock = appParams.Clock - } - if appParams != nil && appParams.EmailTemplates != nil { - a.EmailTemplates = appParams.EmailTemplates - } - if appParams != nil && appParams.Config.OnPremises { - a.Config.OnPremises = appParams.Config.OnPremises - } - if appParams != nil && appParams.Config.WebURL != "" { - a.Config.WebURL = appParams.Config.WebURL - } - if appParams != nil && appParams.Config.DisableRegistration { - a.Config.DisableRegistration = appParams.Config.DisableRegistration - } - - fmt.Printf("%+v\n", appParams) - fmt.Printf("%+v\n", a) - - return a } diff --git a/pkg/server/app/users.go b/pkg/server/app/users.go index a77e8cdb..5c9d7f1f 100644 --- a/pkg/server/app/users.go +++ b/pkg/server/app/users.go @@ -22,11 +22,11 @@ import ( "errors" "github.com/dnote/dnote/pkg/server/database" + "github.com/dnote/dnote/pkg/server/helpers" "github.com/dnote/dnote/pkg/server/log" - "github.com/dnote/dnote/pkg/server/token" - "gorm.io/gorm" pkgErrors "github.com/pkg/errors" "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" ) // TouchLastLoginAt updates the last login timestamp @@ -39,17 +39,6 @@ func (a *App) TouchLastLoginAt(user database.User, tx *gorm.DB) error { return nil } -func createEmailPreference(user database.User, tx *gorm.DB) error { - p := database.EmailPreference{ - UserID: user.ID, - } - if err := tx.Save(&p).Error; err != nil { - return pkgErrors.Wrap(err, "inserting email preference") - } - - return nil -} - // CreateUser creates a user func (a *App) CreateUser(email, password string, passwordConfirmation string) (database.User, error) { if email == "" { @@ -80,16 +69,14 @@ func (a *App) CreateUser(email, password string, passwordConfirmation string) (d return database.User{}, pkgErrors.Wrap(err, "hashing password") } - // Grant all privileges if self-hosting - var pro bool - if a.Config.OnPremises { - pro = true - } else { - pro = false + uuid, err := helpers.GenUUID() + if err != nil { + tx.Rollback() + return database.User{}, pkgErrors.Wrap(err, "generating UUID") } user := database.User{ - Cloud: pro, + UUID: uuid, } if err = tx.Save(&user).Error; err != nil { tx.Rollback() @@ -105,14 +92,6 @@ func (a *App) CreateUser(email, password string, passwordConfirmation string) (d return database.User{}, pkgErrors.Wrap(err, "saving account") } - if _, err := token.Create(tx, user.ID, database.TokenTypeEmailPreference); err != nil { - tx.Rollback() - return database.User{}, pkgErrors.Wrap(err, "creating email verificaiton token") - } - if err := createEmailPreference(user, tx); err != nil { - tx.Rollback() - return database.User{}, pkgErrors.Wrap(err, "creating email preference") - } if err := a.TouchLastLoginAt(user, tx); err != nil { tx.Rollback() return database.User{}, pkgErrors.Wrap(err, "updating last login") diff --git a/pkg/server/app/users_test.go b/pkg/server/app/users_test.go index ba33d11e..45c43514 100644 --- a/pkg/server/app/users_test.go +++ b/pkg/server/app/users_test.go @@ -19,11 +19,9 @@ package app import ( - "fmt" "testing" "github.com/dnote/dnote/pkg/assert" - "github.com/dnote/dnote/pkg/server/config" "github.com/dnote/dnote/pkg/server/database" "github.com/dnote/dnote/pkg/server/testutils" "github.com/pkg/errors" @@ -31,65 +29,41 @@ import ( ) func TestCreateUser_ProValue(t *testing.T) { - testCases := []struct { - onPremises bool - expectedPro bool - }{ - { - onPremises: true, - expectedPro: true, - }, - { - onPremises: false, - expectedPro: false, - }, + db := testutils.InitMemoryDB(t) + + a := NewTest() + a.DB = db + if _, err := a.CreateUser("alice@example.com", "pass1234", "pass1234"); err != nil { + t.Fatal(errors.Wrap(err, "executing")) } - for _, tc := range testCases { - t.Run(fmt.Sprintf("self hosting %t", tc.onPremises), func(t *testing.T) { - c := config.Load() - c.SetOnPremises(tc.onPremises) + var userCount int64 + var userRecord database.User + testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user") + testutils.MustExec(t, db.First(&userRecord), "finding user") - defer testutils.ClearData(testutils.DB) + assert.Equal(t, userCount, int64(1), "book count mismatch") - a := NewTest(&App{ - Config: c, - }) - if _, err := a.CreateUser("alice@example.com", "pass1234", "pass1234"); err != nil { - t.Fatal(errors.Wrap(err, "executing")) - } - - var userCount int64 - var userRecord database.User - testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user") - testutils.MustExec(t, testutils.DB.First(&userRecord), "finding user") - - assert.Equal(t, userCount, int64(1), "book count mismatch") - assert.Equal(t, userRecord.Cloud, tc.expectedPro, "user pro mismatch") - }) - } } func TestCreateUser(t *testing.T) { t.Run("success", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) - c := config.Load() - a := NewTest(&App{ - Config: c, - }) + a := NewTest() + a.DB = db if _, err := a.CreateUser("alice@example.com", "pass1234", "pass1234"); err != nil { t.Fatal(errors.Wrap(err, "executing")) } var userCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user") + testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user") assert.Equal(t, userCount, int64(1), "book count mismatch") var accountCount int64 var accountRecord database.Account - testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account") - testutils.MustExec(t, testutils.DB.First(&accountRecord), "finding account") + testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account") + testutils.MustExec(t, db.First(&accountRecord), "finding account") assert.Equal(t, accountCount, int64(1), "account count mismatch") assert.Equal(t, accountRecord.Email.String, "alice@example.com", "account email mismatch") @@ -99,19 +73,20 @@ func TestCreateUser(t *testing.T) { }) t.Run("duplicate email", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) - aliceUser := testutils.SetupUserData() - testutils.SetupAccountData(aliceUser, "alice@example.com", "somepassword") + aliceUser := testutils.SetupUserData(db) + testutils.SetupAccountData(db, aliceUser, "alice@example.com", "somepassword") - a := NewTest(nil) + a := NewTest() + a.DB = db _, err := a.CreateUser("alice@example.com", "newpassword", "newpassword") assert.Equal(t, err, ErrDuplicateEmail, "error mismatch") var userCount, accountCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user") - testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account") + testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user") + testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account") assert.Equal(t, userCount, int64(1), "user count mismatch") assert.Equal(t, accountCount, int64(1), "account count mismatch") diff --git a/pkg/server/config/config.go b/pkg/server/config/config.go index 5a4dfa58..916bae4e 100644 --- a/pkg/server/config/config.go +++ b/pkg/server/config/config.go @@ -19,149 +19,94 @@ package config import ( - "fmt" "net/url" "os" + "path/filepath" + "github.com/dnote/dnote/pkg/dirs" "github.com/dnote/dnote/pkg/server/assets" - "github.com/dnote/dnote/pkg/server/log" "github.com/pkg/errors" ) const ( // AppEnvProduction represents an app environment for production. AppEnvProduction string = "PRODUCTION" + // DefaultDBDir is the default directory name for Dnote data + DefaultDBDir = "dnote" + // DefaultDBFilename is the default database filename + DefaultDBFilename = "server.db" ) var ( - // ErrDBMissingHost is an error for an incomplete configuration missing the host - ErrDBMissingHost = errors.New("DB Host is empty") - // ErrDBMissingPort is an error for an incomplete configuration missing the port - ErrDBMissingPort = errors.New("DB Port is empty") - // ErrDBMissingName is an error for an incomplete configuration missing the name - ErrDBMissingName = errors.New("DB Name is empty") - // ErrDBMissingUser is an error for an incomplete configuration missing the user - ErrDBMissingUser = errors.New("DB User is empty") + // DefaultDBPath is the default path to the database file + DefaultDBPath = filepath.Join(dirs.DataHome, DefaultDBDir, DefaultDBFilename) +) + +var ( + // ErrDBMissingPath is an error for an incomplete configuration missing the database path + ErrDBMissingPath = errors.New("DB Path is empty") // ErrWebURLInvalid is an error for an incomplete configuration with invalid web url ErrWebURLInvalid = errors.New("Invalid WebURL") // ErrPortInvalid is an error for an incomplete configuration with invalid port ErrPortInvalid = errors.New("Invalid Port") ) -// PostgresConfig holds the postgres connection configuration. -type PostgresConfig struct { - SSLMode string - Host string - Port string - Name string - User string - Password string -} - func readBoolEnv(name string) bool { - if os.Getenv(name) == "true" { - return true - } - - return false + return os.Getenv(name) == "true" } -// checkSSLMode checks if SSL is required for the database connection -func checkSSLMode() bool { - // TODO: deprecate DB_NOSSL in favor of DBSkipSSL - if os.Getenv("DB_NOSSL") != "" { - return true +// getOrEnv returns value if non-empty, otherwise env var, otherwise default +func getOrEnv(value, envKey, defaultVal string) string { + if value != "" { + return value } - - if os.Getenv("DBSkipSSL") == "true" { - return true - } - - return os.Getenv("GO_ENV") != "PRODUCTION" -} - -func loadDBConfig() PostgresConfig { - var sslmode string - if checkSSLMode() { - sslmode = "disable" - } else { - sslmode = "require" - } - - return PostgresConfig{ - SSLMode: sslmode, - Host: os.Getenv("DBHost"), - Port: os.Getenv("DBPort"), - Name: os.Getenv("DBName"), - User: os.Getenv("DBUser"), - Password: os.Getenv("DBPassword"), + if env := os.Getenv(envKey); env != "" { + return env } + return defaultVal } // Config is an application configuration type Config struct { AppEnv string WebURL string - OnPremises bool DisableRegistration bool Port string - DB PostgresConfig + DBPath string AssetBaseURL string HTTP500Page []byte + LogLevel string } -func getAppEnv() string { - // DEPRECATED - goEnv := os.Getenv("GO_ENV") - if goEnv != "" { - return goEnv - } - - return os.Getenv("APP_ENV") +// Params are the configuration parameters for creating a new Config +type Params struct { + AppEnv string + Port string + WebURL string + DBPath string + DisableRegistration bool + LogLevel string } -func checkDeprecatedEnvVars() { - if os.Getenv("OnPremise") != "" { - - log.WithFields(log.Fields{}).Warn("Environment variable 'OnPremise' is deprecated. Please use OnPremises.") - } -} - -// Load constructs and returns a new config based on the environment variables. -func Load() Config { - port := os.Getenv("PORT") - if port == "" { - port = "3000" - } - - checkDeprecatedEnvVars() - +// New constructs and returns a new validated config. +// Empty string params will fall back to environment variables and defaults. +func New(p Params) (Config, error) { c := Config{ - AppEnv: getAppEnv(), - WebURL: os.Getenv("WebURL"), - Port: port, - OnPremises: readBoolEnv("OnPremise") || readBoolEnv("OnPremises"), - DisableRegistration: readBoolEnv("DisableRegistration"), - DB: loadDBConfig(), - AssetBaseURL: "", + AppEnv: getOrEnv(p.AppEnv, "APP_ENV", AppEnvProduction), + Port: getOrEnv(p.Port, "PORT", "3000"), + WebURL: getOrEnv(p.WebURL, "WebURL", ""), + DBPath: getOrEnv(p.DBPath, "DBPath", DefaultDBPath), + DisableRegistration: p.DisableRegistration || readBoolEnv("DisableRegistration"), + LogLevel: getOrEnv(p.LogLevel, "LOG_LEVEL", "info"), + AssetBaseURL: "/static", HTTP500Page: assets.MustGetHTTP500ErrorPage(), } if err := validate(c); err != nil { - panic(err) + return Config{}, err } - return c -} - -// SetOnPremises sets the OnPremise value -func (c *Config) SetOnPremises(val bool) { - c.OnPremises = val -} - -// SetAssetBaseURL sets static dir for the confi -func (c *Config) SetAssetBaseURL(d string) { - c.AssetBaseURL = d + return c, nil } // IsProd checks if the app environment is configured to be production. @@ -171,31 +116,15 @@ func (c Config) IsProd() bool { func validate(c Config) error { if _, err := url.ParseRequestURI(c.WebURL); err != nil { - return errors.Wrapf(ErrWebURLInvalid, "provided: '%s'", c.WebURL) + return errors.Wrapf(ErrWebURLInvalid, "'%s'", c.WebURL) } if c.Port == "" { return ErrPortInvalid } - if c.DB.Host == "" { - return ErrDBMissingHost - } - if c.DB.Port == "" { - return ErrDBMissingPort - } - if c.DB.Name == "" { - return ErrDBMissingName - } - if c.DB.User == "" { - return ErrDBMissingUser + if c.DBPath == "" { + return ErrDBMissingPath } return nil } - -// GetConnectionStr returns a postgres connection string. -func (c PostgresConfig) GetConnectionStr() string { - return fmt.Sprintf( - "sslmode=%s host=%s port=%s dbname=%s user=%s password=%s", - c.SSLMode, c.Host, c.Port, c.Name, c.User, c.Password) -} diff --git a/pkg/server/config/config_test.go b/pkg/server/config/config_test.go index 9f57b404..802a3add 100644 --- a/pkg/server/config/config_test.go +++ b/pkg/server/config/config_test.go @@ -33,12 +33,7 @@ func TestValidate(t *testing.T) { }{ { config: Config{ - DB: PostgresConfig{ - Host: "mockHost", - Port: "5432", - Name: "mockDB", - User: "mockUser", - }, + DBPath: "test.db", WebURL: "http://mock.url", Port: "3000", }, @@ -46,71 +41,21 @@ func TestValidate(t *testing.T) { }, { config: Config{ - DB: PostgresConfig{ - Port: "5432", - Name: "mockDB", - User: "mockUser", - }, + DBPath: "", WebURL: "http://mock.url", Port: "3000", }, - expectedErr: ErrDBMissingHost, + expectedErr: ErrDBMissingPath, }, { config: Config{ - DB: PostgresConfig{ - Host: "mockHost", - Name: "mockDB", - User: "mockUser", - }, - WebURL: "http://mock.url", - Port: "3000", - }, - expectedErr: ErrDBMissingPort, - }, - { - config: Config{ - DB: PostgresConfig{ - Host: "mockHost", - Port: "5432", - User: "mockUser", - }, - WebURL: "http://mock.url", - Port: "3000", - }, - expectedErr: ErrDBMissingName, - }, - { - config: Config{ - DB: PostgresConfig{ - Host: "mockHost", - Port: "5432", - Name: "mockDB", - }, - WebURL: "http://mock.url", - Port: "3000", - }, - expectedErr: ErrDBMissingUser, - }, - { - config: Config{ - DB: PostgresConfig{ - Host: "mockHost", - Port: "5432", - Name: "mockDB", - User: "mockUser", - }, + DBPath: "test.db", }, expectedErr: ErrWebURLInvalid, }, { config: Config{ - DB: PostgresConfig{ - Host: "mockHost", - Port: "5432", - Name: "mockDB", - User: "mockUser", - }, + DBPath: "test.db", WebURL: "http://mock.url", }, expectedErr: ErrPortInvalid, diff --git a/pkg/server/controllers/books_test.go b/pkg/server/controllers/books_test.go index 6f0e3704..d59dcd17 100644 --- a/pkg/server/controllers/books_test.go +++ b/pkg/server/controllers/books_test.go @@ -21,69 +21,78 @@ package controllers import ( "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "testing" + "time" "github.com/dnote/dnote/pkg/assert" "github.com/dnote/dnote/pkg/clock" "github.com/dnote/dnote/pkg/server/app" - "github.com/dnote/dnote/pkg/server/config" "github.com/dnote/dnote/pkg/server/database" "github.com/dnote/dnote/pkg/server/presenters" "github.com/dnote/dnote/pkg/server/testutils" "github.com/pkg/errors" ) +// truncateMicro rounds time to microsecond precision to match SQLite storage +func truncateMicro(t time.Time) time.Time { + return t.Round(time.Microsecond) +} + func TestGetBooks(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.DB = db + a.Clock = clock.NewMock() + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@test.com", "pass1234") - anotherUser := testutils.SetupUserData() - testutils.SetupAccountData(anotherUser, "bob@test.com", "pass1234") + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + anotherUser := testutils.SetupUserData(db) + testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234") b1 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: "js", USN: 1123, Deleted: false, } - testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1") + testutils.MustExec(t, db.Save(&b1), "preparing b1") b2 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: "css", USN: 1125, Deleted: false, } - testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2") + testutils.MustExec(t, db.Save(&b2), "preparing b2") b3 := database.Book{ + UUID: testutils.MustUUID(t), UserID: anotherUser.ID, Label: "css", USN: 1128, Deleted: false, } - testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3") + testutils.MustExec(t, db.Save(&b3), "preparing b3") b4 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: "", USN: 1129, Deleted: true, } - testutils.MustExec(t, testutils.DB.Save(&b4), "preparing b4") + testutils.MustExec(t, db.Save(&b4), "preparing b4") // Execute endpoint := "/api/v3/books" req := testutils.MakeReq(server.URL, "GET", endpoint, "") - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusOK, "") @@ -94,66 +103,75 @@ func TestGetBooks(t *testing.T) { } var b1Record, b2Record database.Book - testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1") - testutils.MustExec(t, testutils.DB.Where("id = ?", b2.ID).First(&b2Record), "finding b2") - testutils.MustExec(t, testutils.DB.Where("id = ?", b2.ID).First(&b2Record), "finding b2") + testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1") + testutils.MustExec(t, db.Where("id = ?", b2.ID).First(&b2Record), "finding b2") + testutils.MustExec(t, db.Where("id = ?", b2.ID).First(&b2Record), "finding b2") expected := []presenters.Book{ { UUID: b2Record.UUID, - CreatedAt: b2Record.CreatedAt, - UpdatedAt: b2Record.UpdatedAt, + CreatedAt: truncateMicro(b2Record.CreatedAt), + UpdatedAt: truncateMicro(b2Record.UpdatedAt), Label: b2Record.Label, USN: b2Record.USN, }, { UUID: b1Record.UUID, - CreatedAt: b1Record.CreatedAt, - UpdatedAt: b1Record.UpdatedAt, + CreatedAt: truncateMicro(b1Record.CreatedAt), + UpdatedAt: truncateMicro(b1Record.UpdatedAt), Label: b1Record.Label, USN: b1Record.USN, }, } + // Truncate payload timestamps to match SQLite precision + for i := range payload { + payload[i].CreatedAt = truncateMicro(payload[i].CreatedAt) + payload[i].UpdatedAt = truncateMicro(payload[i].UpdatedAt) + } + assert.DeepEqual(t, payload, expected, "payload mismatch") } func TestGetBooksByName(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.DB = db + a.Clock = clock.NewMock() + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@test.com", "pass1234") - anotherUser := testutils.SetupUserData() - testutils.SetupAccountData(anotherUser, "bob@test.com", "pass1234") + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + anotherUser := testutils.SetupUserData(db) + testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234") b1 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: "js", } - testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1") + testutils.MustExec(t, db.Save(&b1), "preparing b1") b2 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: "css", } - testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2") + testutils.MustExec(t, db.Save(&b2), "preparing b2") b3 := database.Book{ + UUID: testutils.MustUUID(t), UserID: anotherUser.ID, Label: "js", } - testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3") + testutils.MustExec(t, db.Save(&b3), "preparing b3") // Execute endpoint := "/api/v3/books?name=js" req := testutils.MakeReq(server.URL, "GET", endpoint, "") - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusOK, "") @@ -164,56 +182,64 @@ func TestGetBooksByName(t *testing.T) { } var b1Record database.Book - testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1") + testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1") expected := []presenters.Book{ { UUID: b1Record.UUID, - CreatedAt: b1Record.CreatedAt, - UpdatedAt: b1Record.UpdatedAt, + CreatedAt: truncateMicro(b1Record.CreatedAt), + UpdatedAt: truncateMicro(b1Record.UpdatedAt), Label: b1Record.Label, USN: b1Record.USN, }, } + for i := range payload { + payload[i].CreatedAt = truncateMicro(payload[i].CreatedAt) + payload[i].UpdatedAt = truncateMicro(payload[i].UpdatedAt) + } + assert.DeepEqual(t, payload, expected, "payload mismatch") } func TestGetBook(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.DB = db + a.Clock = clock.NewMock() + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@test.com", "pass1234") - anotherUser := testutils.SetupUserData() - testutils.SetupAccountData(anotherUser, "bob@test.com", "pass1234") + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + anotherUser := testutils.SetupUserData(db) + testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234") b1 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: "js", } - testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1") + testutils.MustExec(t, db.Save(&b1), "preparing b1") b2 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: "css", } - testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2") + testutils.MustExec(t, db.Save(&b2), "preparing b2") b3 := database.Book{ + UUID: testutils.MustUUID(t), UserID: anotherUser.ID, Label: "js", } - testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3") + testutils.MustExec(t, db.Save(&b3), "preparing b3") // Execute endpoint := fmt.Sprintf("/api/v3/books/%s", b1.UUID) req := testutils.MakeReq(server.URL, "GET", endpoint, "") - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusOK, "") @@ -224,49 +250,53 @@ func TestGetBook(t *testing.T) { } var b1Record database.Book - testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1") + testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1") expected := presenters.Book{ UUID: b1Record.UUID, - CreatedAt: b1Record.CreatedAt, - UpdatedAt: b1Record.UpdatedAt, + CreatedAt: truncateMicro(b1Record.CreatedAt), + UpdatedAt: truncateMicro(b1Record.UpdatedAt), Label: b1Record.Label, USN: b1Record.USN, } + payload.CreatedAt = truncateMicro(payload.CreatedAt) + payload.UpdatedAt = truncateMicro(payload.UpdatedAt) + assert.DeepEqual(t, payload, expected, "payload mismatch") } func TestGetBookNonOwner(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.DB = db + a.Clock = clock.NewMock() + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@test.com", "pass1234") - nonOwner := testutils.SetupUserData() - testutils.SetupAccountData(nonOwner, "bob@test.com", "pass1234") + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + nonOwner := testutils.SetupUserData(db) + testutils.SetupAccountData(db, nonOwner, "bob@test.com", "pass1234") b1 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: "js", } - testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1") + testutils.MustExec(t, db.Save(&b1), "preparing b1") // Execute endpoint := fmt.Sprintf("/api/v3/books/%s", b1.UUID) req := testutils.MakeReq(server.URL, "GET", endpoint, "") - res := testutils.HTTPAuthDo(t, req, nonOwner) + res := testutils.HTTPAuthDo(t, db, req, nonOwner) // Test assert.StatusCodeEquals(t, res, http.StatusNotFound, "") - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(errors.Wrap(err, "reading body")) } @@ -275,23 +305,23 @@ func TestGetBookNonOwner(t *testing.T) { func TestCreateBook(t *testing.T) { t.Run("success", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.DB = db + a.Clock = clock.NewMock() + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@test.com", "pass1234") - testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn") + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn") req := testutils.MakeReq(server.URL, "POST", "/api/v3/books", `{"name": "js"}`) // Execute - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusCreated, "") @@ -299,10 +329,10 @@ func TestCreateBook(t *testing.T) { var bookRecord database.Book var userRecord database.User var bookCount, noteCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books") - testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes") - testutils.MustExec(t, testutils.DB.First(&bookRecord), "finding book") - testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record") + testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books") + testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes") + testutils.MustExec(t, db.First(&bookRecord), "finding book") + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record") maxUSN := 102 @@ -323,39 +353,43 @@ func TestCreateBook(t *testing.T) { Book: presenters.Book{ UUID: bookRecord.UUID, USN: bookRecord.USN, - CreatedAt: bookRecord.CreatedAt, - UpdatedAt: bookRecord.UpdatedAt, + CreatedAt: truncateMicro(bookRecord.CreatedAt), + UpdatedAt: truncateMicro(bookRecord.UpdatedAt), Label: "js", }, } + got.Book.CreatedAt = truncateMicro(got.Book.CreatedAt) + got.Book.UpdatedAt = truncateMicro(got.Book.UpdatedAt) + assert.DeepEqual(t, got, expected, "payload mismatch") }) t.Run("duplicate", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.DB = db + a.Clock = clock.NewMock() + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@test.com", "pass1234") - testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn") + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn") b1 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: "js", USN: 58, } - testutils.MustExec(t, testutils.DB.Save(&b1), "preparing book data") + testutils.MustExec(t, db.Save(&b1), "preparing book data") // Execute req := testutils.MakeReq(server.URL, "POST", "/api/v3/books", `{"name": "js"}`) - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusConflict, "") @@ -363,10 +397,10 @@ func TestCreateBook(t *testing.T) { var bookRecord database.Book var bookCount, noteCount int64 var userRecord database.User - testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books") - testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes") - testutils.MustExec(t, testutils.DB.First(&bookRecord), "finding book") - testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record") + testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books") + testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes") + testutils.MustExec(t, db.First(&bookRecord), "finding book") + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record") assert.Equalf(t, bookCount, int64(1), "book count mismatch") assert.Equalf(t, noteCount, int64(0), "note count mismatch") @@ -422,18 +456,18 @@ func TestUpdateBook(t *testing.T) { for idx, tc := range testCases { t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.DB = db + a.Clock = clock.NewMock() + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@test.com", "pass1234") - testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn") + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn") b1 := database.Book{ UUID: tc.bookUUID, @@ -441,18 +475,18 @@ func TestUpdateBook(t *testing.T) { Label: tc.bookLabel, Deleted: tc.bookDeleted, } - testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1") + testutils.MustExec(t, db.Save(&b1), "preparing b1") b2 := database.Book{ UUID: b2UUID, UserID: user.ID, Label: "js", } - testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2") + testutils.MustExec(t, db.Save(&b2), "preparing b2") // Execute endpoint := fmt.Sprintf("/api/v3/books/%s", tc.bookUUID) req := testutils.MakeReq(server.URL, "PATCH", endpoint, tc.payload.ToJSON(t)) - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusOK, fmt.Sprintf("status code mismatch for test case %d", idx)) @@ -460,10 +494,10 @@ func TestUpdateBook(t *testing.T) { var bookRecord database.Book var userRecord database.User var noteCount, bookCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books") - testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes") - testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book") - testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record") + testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books") + testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes") + testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book") + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record") assert.Equalf(t, bookCount, int64(2), "book count mismatch") assert.Equalf(t, noteCount, int64(0), "note count mismatch") @@ -507,41 +541,44 @@ func TestDeleteBook(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("originally deleted %t", tc.deleted), func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.DB = db + a.Clock = clock.NewMock() + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@test.com", "pass1234") - testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 58), "preparing user max_usn") - anotherUser := testutils.SetupUserData() - testutils.SetupAccountData(anotherUser, "bob@test.com", "pass1234") - testutils.MustExec(t, testutils.DB.Model(&anotherUser).Update("max_usn", 109), "preparing another user max_usn") + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + testutils.MustExec(t, db.Model(&user).Update("max_usn", 58), "preparing user max_usn") + anotherUser := testutils.SetupUserData(db) + testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234") + testutils.MustExec(t, db.Model(&anotherUser).Update("max_usn", 109), "preparing another user max_usn") b1 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: "js", USN: 1, } - testutils.MustExec(t, testutils.DB.Save(&b1), "preparing a book data") + testutils.MustExec(t, db.Save(&b1), "preparing a book data") b2 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: tc.label, USN: 2, Deleted: tc.deleted, } - testutils.MustExec(t, testutils.DB.Save(&b2), "preparing a book data") + testutils.MustExec(t, db.Save(&b2), "preparing a book data") b3 := database.Book{ + UUID: testutils.MustUUID(t), UserID: anotherUser.ID, Label: "linux", USN: 3, } - testutils.MustExec(t, testutils.DB.Save(&b3), "preparing a book data") + testutils.MustExec(t, db.Save(&b3), "preparing a book data") var n2Body string if !tc.deleted { @@ -553,49 +590,54 @@ func TestDeleteBook(t *testing.T) { } n1 := database.Note{ + UUID: testutils.MustUUID(t), UserID: user.ID, BookUUID: b1.UUID, Body: "n1 content", USN: 4, } - testutils.MustExec(t, testutils.DB.Save(&n1), "preparing a note data") + testutils.MustExec(t, db.Save(&n1), "preparing a note data") n2 := database.Note{ + UUID: testutils.MustUUID(t), UserID: user.ID, BookUUID: b2.UUID, Body: n2Body, USN: 5, Deleted: tc.deleted, } - testutils.MustExec(t, testutils.DB.Save(&n2), "preparing a note data") + testutils.MustExec(t, db.Save(&n2), "preparing a note data") n3 := database.Note{ + UUID: testutils.MustUUID(t), UserID: user.ID, BookUUID: b2.UUID, Body: n3Body, USN: 6, Deleted: tc.deleted, } - testutils.MustExec(t, testutils.DB.Save(&n3), "preparing a note data") + testutils.MustExec(t, db.Save(&n3), "preparing a note data") n4 := database.Note{ + UUID: testutils.MustUUID(t), UserID: user.ID, BookUUID: b2.UUID, Body: "", USN: 7, Deleted: true, } - testutils.MustExec(t, testutils.DB.Save(&n4), "preparing a note data") + testutils.MustExec(t, db.Save(&n4), "preparing a note data") n5 := database.Note{ + UUID: testutils.MustUUID(t), UserID: anotherUser.ID, BookUUID: b3.UUID, Body: "n5 content", USN: 8, } - testutils.MustExec(t, testutils.DB.Save(&n5), "preparing a note data") + testutils.MustExec(t, db.Save(&n5), "preparing a note data") // Execute endpoint := fmt.Sprintf("/api/v3/books/%s", b2.UUID) req := testutils.MakeReq(server.URL, "DELETE", endpoint, "") - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusOK, "") @@ -605,17 +647,17 @@ func TestDeleteBook(t *testing.T) { var userRecord database.User var bookCount, noteCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books") - testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes") - testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&b1Record), "finding b1") - testutils.MustExec(t, testutils.DB.Where("id = ?", b2.ID).First(&b2Record), "finding b2") - testutils.MustExec(t, testutils.DB.Where("id = ?", b3.ID).First(&b3Record), "finding b3") - testutils.MustExec(t, testutils.DB.Where("id = ?", n1.ID).First(&n1Record), "finding n1") - testutils.MustExec(t, testutils.DB.Where("id = ?", n2.ID).First(&n2Record), "finding n2") - testutils.MustExec(t, testutils.DB.Where("id = ?", n3.ID).First(&n3Record), "finding n3") - testutils.MustExec(t, testutils.DB.Where("id = ?", n4.ID).First(&n4Record), "finding n4") - testutils.MustExec(t, testutils.DB.Where("id = ?", n5.ID).First(&n5Record), "finding n5") - testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record") + testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books") + testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes") + testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&b1Record), "finding b1") + testutils.MustExec(t, db.Where("id = ?", b2.ID).First(&b2Record), "finding b2") + testutils.MustExec(t, db.Where("id = ?", b3.ID).First(&b3Record), "finding b3") + testutils.MustExec(t, db.Where("id = ?", n1.ID).First(&n1Record), "finding n1") + testutils.MustExec(t, db.Where("id = ?", n2.ID).First(&n2Record), "finding n2") + testutils.MustExec(t, db.Where("id = ?", n3.ID).First(&n3Record), "finding n3") + testutils.MustExec(t, db.Where("id = ?", n4.ID).First(&n4Record), "finding n4") + testutils.MustExec(t, db.Where("id = ?", n5.ID).First(&n5Record), "finding n5") + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record") assert.Equal(t, bookCount, int64(3), "book count mismatch") assert.Equal(t, noteCount, int64(5), "note count mismatch") diff --git a/pkg/server/controllers/health_test.go b/pkg/server/controllers/health_test.go index a1d0294e..cec8d35a 100644 --- a/pkg/server/controllers/health_test.go +++ b/pkg/server/controllers/health_test.go @@ -24,16 +24,15 @@ import ( "github.com/dnote/dnote/pkg/assert" "github.com/dnote/dnote/pkg/server/app" - "github.com/dnote/dnote/pkg/server/config" "github.com/dnote/dnote/pkg/server/testutils" ) func TestHealth(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) - server := MustNewServer(t, &app.App{ - Config: config.Config{}, - }) + a := app.NewTest() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() // Execute diff --git a/pkg/server/controllers/helpers.go b/pkg/server/controllers/helpers.go index ba697be0..0e88034b 100644 --- a/pkg/server/controllers/helpers.go +++ b/pkg/server/controllers/helpers.go @@ -61,13 +61,6 @@ func parseForm(r *http.Request, dst interface{}) error { return parseValues(r.PostForm, dst) } -func parseURLParams(r *http.Request, dst interface{}) error { - if err := r.ParseForm(); err != nil { - return err - } - return parseValues(r.Form, dst) -} - func parseValues(values url.Values, dst interface{}) error { dec := schema.NewDecoder() diff --git a/pkg/server/controllers/main_test.go b/pkg/server/controllers/main_test.go index 7e48d502..35eef6da 100644 --- a/pkg/server/controllers/main_test.go +++ b/pkg/server/controllers/main_test.go @@ -21,15 +21,13 @@ package controllers import ( "os" "testing" - - "github.com/dnote/dnote/pkg/server/testutils" + "time" ) func TestMain(m *testing.M) { - testutils.InitTestDB() + // Set timezone to UTC to match database timestamps + time.Local = time.UTC code := m.Run() - testutils.ClearData(testutils.DB) - os.Exit(code) } diff --git a/pkg/server/controllers/notes.go b/pkg/server/controllers/notes.go index 0a12c071..a7434366 100644 --- a/pkg/server/controllers/notes.go +++ b/pkg/server/controllers/notes.go @@ -19,13 +19,10 @@ package controllers import ( - "math" "net/http" "net/url" - "sort" "strconv" "strings" - "time" "github.com/dnote/dnote/pkg/server/app" "github.com/dnote/dnote/pkg/server/context" @@ -150,69 +147,6 @@ func (n *Notes) getNotes(r *http.Request) (app.GetNotesResult, app.GetNotesParam return res, p, nil } -type noteGroup struct { - Year int - Month int - Data []database.Note -} - -type bucketKey struct { - year int - month time.Month -} - -func groupNotes(notes []database.Note) []noteGroup { - ret := []noteGroup{} - - buckets := map[bucketKey][]database.Note{} - - for _, note := range notes { - year := note.UpdatedAt.Year() - month := note.UpdatedAt.Month() - key := bucketKey{year, month} - - if _, ok := buckets[key]; !ok { - buckets[key] = []database.Note{} - } - - buckets[key] = append(buckets[key], note) - } - - keys := []bucketKey{} - for key := range buckets { - keys = append(keys, key) - } - - sort.Slice(keys, func(i, j int) bool { - yearI := keys[i].year - yearJ := keys[j].year - monthI := keys[i].month - monthJ := keys[j].month - - if yearI == yearJ { - return monthI < monthJ - } - - return yearI < yearJ - }) - - for _, key := range keys { - group := noteGroup{ - Year: key.year, - Month: int(key.month), - Data: buckets[key], - } - ret = append(ret, group) - } - - return ret -} - -func getMaxPage(page, total int) int { - tmp := float64(total) / float64(notesPerPage) - return int(math.Ceil(tmp)) -} - // GetNotesResponse is a reponse by getNotesHandler type GetNotesResponse struct { Notes []presenters.Note `json:"notes"` diff --git a/pkg/server/controllers/notes_test.go b/pkg/server/controllers/notes_test.go index caeaa775..6bf3c06b 100644 --- a/pkg/server/controllers/notes_test.go +++ b/pkg/server/controllers/notes_test.go @@ -21,7 +21,7 @@ package controllers import ( "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "testing" "time" @@ -29,7 +29,6 @@ import ( "github.com/dnote/dnote/pkg/assert" "github.com/dnote/dnote/pkg/clock" "github.com/dnote/dnote/pkg/server/app" - "github.com/dnote/dnote/pkg/server/config" "github.com/dnote/dnote/pkg/server/database" "github.com/dnote/dnote/pkg/server/presenters" "github.com/dnote/dnote/pkg/server/testutils" @@ -39,8 +38,8 @@ import ( func getExpectedNotePayload(n database.Note, b database.Book, u database.User) presenters.Note { return presenters.Note{ UUID: n.UUID, - CreatedAt: n.CreatedAt, - UpdatedAt: n.UpdatedAt, + CreatedAt: truncateMicro(n.CreatedAt), + UpdatedAt: truncateMicro(n.UpdatedAt), Body: n.Body, AddedOn: n.AddedOn, Public: n.Public, @@ -56,37 +55,41 @@ func getExpectedNotePayload(n database.Note, b database.Book, u database.User) p } func TestGetNotes(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.DB = db + a.Clock = clock.NewMock() + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@test.com", "pass1234") - anotherUser := testutils.SetupUserData() - testutils.SetupAccountData(anotherUser, "bob@test.com", "pass1234") + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + anotherUser := testutils.SetupUserData(db) + testutils.SetupAccountData(db, anotherUser, "bob@test.com", "pass1234") b1 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: "js", } - testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1") + testutils.MustExec(t, db.Save(&b1), "preparing b1") b2 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: "css", } - testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2") + testutils.MustExec(t, db.Save(&b2), "preparing b2") b3 := database.Book{ + UUID: testutils.MustUUID(t), UserID: anotherUser.ID, Label: "css", } - testutils.MustExec(t, testutils.DB.Save(&b3), "preparing b3") + testutils.MustExec(t, db.Save(&b3), "preparing b3") n1 := database.Note{ + UUID: testutils.MustUUID(t), UserID: user.ID, BookUUID: b1.UUID, Body: "n1 content", @@ -94,8 +97,9 @@ func TestGetNotes(t *testing.T) { Deleted: false, AddedOn: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC).UnixNano(), } - testutils.MustExec(t, testutils.DB.Save(&n1), "preparing n1") + testutils.MustExec(t, db.Save(&n1), "preparing n1") n2 := database.Note{ + UUID: testutils.MustUUID(t), UserID: user.ID, BookUUID: b1.UUID, Body: "n2 content", @@ -103,8 +107,9 @@ func TestGetNotes(t *testing.T) { Deleted: false, AddedOn: time.Date(2018, time.August, 11, 22, 0, 0, 0, time.UTC).UnixNano(), } - testutils.MustExec(t, testutils.DB.Save(&n2), "preparing n2") + testutils.MustExec(t, db.Save(&n2), "preparing n2") n3 := database.Note{ + UUID: testutils.MustUUID(t), UserID: user.ID, BookUUID: b1.UUID, Body: "n3 content", @@ -112,8 +117,9 @@ func TestGetNotes(t *testing.T) { Deleted: false, AddedOn: time.Date(2017, time.January, 10, 23, 0, 0, 0, time.UTC).UnixNano(), } - testutils.MustExec(t, testutils.DB.Save(&n3), "preparing n3") + testutils.MustExec(t, db.Save(&n3), "preparing n3") n4 := database.Note{ + UUID: testutils.MustUUID(t), UserID: user.ID, BookUUID: b2.UUID, Body: "n4 content", @@ -121,8 +127,9 @@ func TestGetNotes(t *testing.T) { Deleted: false, AddedOn: time.Date(2018, time.September, 10, 23, 0, 0, 0, time.UTC).UnixNano(), } - testutils.MustExec(t, testutils.DB.Save(&n4), "preparing n4") + testutils.MustExec(t, db.Save(&n4), "preparing n4") n5 := database.Note{ + UUID: testutils.MustUUID(t), UserID: anotherUser.ID, BookUUID: b3.UUID, Body: "n5 content", @@ -130,8 +137,9 @@ func TestGetNotes(t *testing.T) { Deleted: false, AddedOn: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC).UnixNano(), } - testutils.MustExec(t, testutils.DB.Save(&n5), "preparing n5") + testutils.MustExec(t, db.Save(&n5), "preparing n5") n6 := database.Note{ + UUID: testutils.MustUUID(t), UserID: user.ID, BookUUID: b1.UUID, Body: "", @@ -139,13 +147,13 @@ func TestGetNotes(t *testing.T) { Deleted: true, AddedOn: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC).UnixNano(), } - testutils.MustExec(t, testutils.DB.Save(&n6), "preparing n6") + testutils.MustExec(t, db.Save(&n6), "preparing n6") // Execute endpoint := "/api/v3/notes" req := testutils.MakeReq(server.URL, "GET", fmt.Sprintf("%s?year=2018&month=8", endpoint), "") - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusOK, "") @@ -156,8 +164,8 @@ func TestGetNotes(t *testing.T) { } var n2Record, n1Record database.Note - testutils.MustExec(t, testutils.DB.Where("uuid = ?", n2.UUID).First(&n2Record), "finding n2Record") - testutils.MustExec(t, testutils.DB.Where("uuid = ?", n1.UUID).First(&n1Record), "finding n1Record") + testutils.MustExec(t, db.Where("uuid = ?", n2.UUID).First(&n2Record), "finding n2Record") + testutils.MustExec(t, db.Where("uuid = ?", n1.UUID).First(&n1Record), "finding n1Record") expected := GetNotesResponse{ Notes: []presenters.Note{ @@ -171,44 +179,48 @@ func TestGetNotes(t *testing.T) { } func TestGetNote(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.DB = db + a.Clock = clock.NewMock() + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - anotherUser := testutils.SetupUserData() + user := testutils.SetupUserData(db) + anotherUser := testutils.SetupUserData(db) b1 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: "js", } - testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1") + testutils.MustExec(t, db.Save(&b1), "preparing b1") privateNote := database.Note{ + UUID: testutils.MustUUID(t), UserID: user.ID, BookUUID: b1.UUID, Body: "privateNote content", Public: false, } - testutils.MustExec(t, testutils.DB.Save(&privateNote), "preparing privateNote") + testutils.MustExec(t, db.Save(&privateNote), "preparing privateNote") publicNote := database.Note{ + UUID: testutils.MustUUID(t), UserID: user.ID, BookUUID: b1.UUID, Body: "publicNote content", Public: true, } - testutils.MustExec(t, testutils.DB.Save(&publicNote), "preparing publicNote") + testutils.MustExec(t, db.Save(&publicNote), "preparing publicNote") deletedNote := database.Note{ + UUID: testutils.MustUUID(t), UserID: user.ID, BookUUID: b1.UUID, Deleted: true, } - testutils.MustExec(t, testutils.DB.Save(&deletedNote), "preparing deletedNote") + testutils.MustExec(t, db.Save(&deletedNote), "preparing deletedNote") getURL := func(noteUUID string) string { return fmt.Sprintf("/api/v3/notes/%s", noteUUID) @@ -218,7 +230,7 @@ func TestGetNote(t *testing.T) { // Execute url := getURL(publicNote.UUID) req := testutils.MakeReq(server.URL, "GET", url, "") - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusOK, "") @@ -229,7 +241,7 @@ func TestGetNote(t *testing.T) { } var n2Record database.Note - testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record") + testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record") expected := getExpectedNotePayload(n2Record, b1, user) assert.DeepEqual(t, payload, expected, "payload mismatch") @@ -239,7 +251,7 @@ func TestGetNote(t *testing.T) { // Execute url := getURL(publicNote.UUID) req := testutils.MakeReq(server.URL, "GET", url, "") - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusOK, "") @@ -250,7 +262,7 @@ func TestGetNote(t *testing.T) { } var n2Record database.Note - testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record") + testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record") expected := getExpectedNotePayload(n2Record, b1, user) assert.DeepEqual(t, payload, expected, "payload mismatch") @@ -260,7 +272,7 @@ func TestGetNote(t *testing.T) { // Execute url := getURL(publicNote.UUID) req := testutils.MakeReq(server.URL, "GET", url, "") - res := testutils.HTTPAuthDo(t, req, anotherUser) + res := testutils.HTTPAuthDo(t, db, req, anotherUser) // Test assert.StatusCodeEquals(t, res, http.StatusOK, "") @@ -271,7 +283,7 @@ func TestGetNote(t *testing.T) { } var n2Record database.Note - testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record") + testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record") expected := getExpectedNotePayload(n2Record, b1, user) assert.DeepEqual(t, payload, expected, "payload mismatch") @@ -281,12 +293,12 @@ func TestGetNote(t *testing.T) { // Execute url := getURL(privateNote.UUID) req := testutils.MakeReq(server.URL, "GET", url, "") - res := testutils.HTTPAuthDo(t, req, anotherUser) + res := testutils.HTTPAuthDo(t, db, req, anotherUser) // Test assert.StatusCodeEquals(t, res, http.StatusNotFound, "") - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(errors.Wrap(err, "reading body")) } @@ -309,7 +321,7 @@ func TestGetNote(t *testing.T) { } var n2Record database.Note - testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record") + testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).First(&n2Record), "finding n2Record") expected := getExpectedNotePayload(n2Record, b1, user) assert.DeepEqual(t, payload, expected, "payload mismatch") @@ -324,7 +336,7 @@ func TestGetNote(t *testing.T) { // Test assert.StatusCodeEquals(t, res, http.StatusNotFound, "") - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(errors.Wrap(err, "reading body")) } @@ -336,12 +348,12 @@ func TestGetNote(t *testing.T) { // Execute url := getURL("somerandomstring") req := testutils.MakeReq(server.URL, "GET", url, "") - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusNotFound, "") - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(errors.Wrap(err, "reading body")) } @@ -353,12 +365,12 @@ func TestGetNote(t *testing.T) { // Execute url := getURL(deletedNote.UUID) req := testutils.MakeReq(server.URL, "GET", url, "") - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusNotFound, "") - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(errors.Wrap(err, "reading body")) } @@ -368,31 +380,32 @@ func TestGetNote(t *testing.T) { } func TestCreateNote(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.DB = db + a.Clock = clock.NewMock() + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@test.com", "pass1234") - testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn") + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn") b1 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: "js", USN: 58, } - testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1") + testutils.MustExec(t, db.Save(&b1), "preparing b1") // Execute dat := fmt.Sprintf(`{"book_uuid": "%s", "content": "note content"}`, b1.UUID) req := testutils.MakeReq(server.URL, "POST", "/api/v3/notes", dat) - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusCreated, "") @@ -401,11 +414,11 @@ func TestCreateNote(t *testing.T) { var bookRecord database.Book var userRecord database.User var bookCount, noteCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books") - testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes") - testutils.MustExec(t, testutils.DB.First(¬eRecord), "finding note") - testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book") - testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record") + testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books") + testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes") + testutils.MustExec(t, db.First(¬eRecord), "finding note") + testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book") + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record") assert.Equalf(t, bookCount, int64(1), "book count mismatch") assert.Equalf(t, noteCount, int64(1), "note count mismatch") @@ -449,38 +462,39 @@ func TestDeleteNote(t *testing.T) { for idx, tc := range testCases { t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.DB = db + a.Clock = clock.NewMock() + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@test.com", "pass1234") - testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 981), "preparing user max_usn") + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + testutils.MustExec(t, db.Model(&user).Update("max_usn", 981), "preparing user max_usn") b1 := database.Book{ UUID: b1UUID, UserID: user.ID, Label: "js", } - testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1") + testutils.MustExec(t, db.Save(&b1), "preparing b1") note := database.Note{ + UUID: testutils.MustUUID(t), UserID: user.ID, BookUUID: b1.UUID, Body: tc.content, Deleted: tc.deleted, USN: tc.originalUSN, } - testutils.MustExec(t, testutils.DB.Save(¬e), "preparing note") + testutils.MustExec(t, db.Save(¬e), "preparing note") // Execute endpoint := fmt.Sprintf("/api/v3/notes/%s", note.UUID) req := testutils.MakeReq(server.URL, "DELETE", endpoint, "") - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusOK, "") @@ -489,11 +503,11 @@ func TestDeleteNote(t *testing.T) { var noteRecord database.Note var userRecord database.User var bookCount, noteCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books") - testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes") - testutils.MustExec(t, testutils.DB.Where("uuid = ?", note.UUID).First(¬eRecord), "finding note") - testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book") - testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record") + testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books") + testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes") + testutils.MustExec(t, db.Where("uuid = ?", note.UUID).First(¬eRecord), "finding note") + testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book") + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record") assert.Equalf(t, bookCount, int64(1), "book count mismatch") assert.Equalf(t, noteCount, int64(1), "note count mismatch") @@ -687,42 +701,42 @@ func TestUpdateNote(t *testing.T) { for idx, tc := range testCases { t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.DB = db + a.Clock = clock.NewMock() + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@test.com", "pass1234") + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") - testutils.MustExec(t, testutils.DB.Model(&user).Update("max_usn", 101), "preparing user max_usn") + testutils.MustExec(t, db.Model(&user).Update("max_usn", 101), "preparing user max_usn") b1 := database.Book{ UUID: b1UUID, UserID: user.ID, Label: "css", } - testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1") + testutils.MustExec(t, db.Save(&b1), "preparing b1") b2 := database.Book{ UUID: b2UUID, UserID: user.ID, Label: "js", } - testutils.MustExec(t, testutils.DB.Save(&b2), "preparing b2") + testutils.MustExec(t, db.Save(&b2), "preparing b2") note := database.Note{ - UserID: user.ID, UUID: tc.noteUUID, + UserID: user.ID, BookUUID: tc.noteBookUUID, Body: tc.noteBody, Deleted: tc.noteDeleted, Public: tc.notePublic, } - testutils.MustExec(t, testutils.DB.Save(¬e), "preparing note") + testutils.MustExec(t, db.Save(¬e), "preparing note") // Execute var req *http.Request @@ -730,7 +744,7 @@ func TestUpdateNote(t *testing.T) { endpoint := fmt.Sprintf("/api/v3/notes/%s", note.UUID) req = testutils.MakeReq(server.URL, "PATCH", endpoint, tc.payload.ToJSON(t)) - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusOK, "status code mismatch for test case") @@ -739,11 +753,11 @@ func TestUpdateNote(t *testing.T) { var noteRecord database.Note var userRecord database.User var noteCount, bookCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.Book{}).Count(&bookCount), "counting books") - testutils.MustExec(t, testutils.DB.Model(&database.Note{}).Count(¬eCount), "counting notes") - testutils.MustExec(t, testutils.DB.Where("uuid = ?", note.UUID).First(¬eRecord), "finding note") - testutils.MustExec(t, testutils.DB.Where("id = ?", b1.ID).First(&bookRecord), "finding book") - testutils.MustExec(t, testutils.DB.Where("id = ?", user.ID).First(&userRecord), "finding user record") + testutils.MustExec(t, db.Model(&database.Book{}).Count(&bookCount), "counting books") + testutils.MustExec(t, db.Model(&database.Note{}).Count(¬eCount), "counting notes") + testutils.MustExec(t, db.Where("uuid = ?", note.UUID).First(¬eRecord), "finding note") + testutils.MustExec(t, db.Where("id = ?", b1.ID).First(&bookRecord), "finding book") + testutils.MustExec(t, db.Where("id = ?", user.ID).First(&userRecord), "finding user record") assert.Equalf(t, bookCount, int64(2), "book count mismatch") assert.Equalf(t, noteCount, int64(1), "note count mismatch") diff --git a/pkg/server/controllers/routes.go b/pkg/server/controllers/routes.go index 3e80a5ed..0c873fa3 100644 --- a/pkg/server/controllers/routes.go +++ b/pkg/server/controllers/routes.go @@ -48,25 +48,25 @@ func NewWebRoutes(a *app.App, c *Controllers) []Route { redirectGuest := &mw.AuthParams{RedirectGuestsToLogin: true} ret := []Route{ - {"GET", "/", mw.Auth(a, c.Users.Settings, redirectGuest), true}, - {"GET", "/about", mw.Auth(a, c.Users.About, redirectGuest), true}, - {"GET", "/login", mw.GuestOnly(a, c.Users.NewLogin), true}, - {"POST", "/login", mw.GuestOnly(a, c.Users.Login), true}, + {"GET", "/", mw.Auth(a.DB, c.Users.Settings, redirectGuest), true}, + {"GET", "/about", mw.Auth(a.DB, c.Users.About, redirectGuest), true}, + {"GET", "/login", mw.GuestOnly(a.DB, c.Users.NewLogin), true}, + {"POST", "/login", mw.GuestOnly(a.DB, c.Users.Login), true}, {"POST", "/logout", c.Users.Logout, true}, {"GET", "/password-reset", c.Users.PasswordResetView.ServeHTTP, true}, {"PATCH", "/password-reset", c.Users.PasswordReset, true}, {"GET", "/password-reset/{token}", c.Users.PasswordResetConfirm, true}, {"POST", "/reset-token", c.Users.CreateResetToken, true}, - {"POST", "/verification-token", mw.Auth(a, c.Users.CreateEmailVerificationToken, redirectGuest), true}, - {"GET", "/verify-email/{token}", mw.Auth(a, c.Users.VerifyEmail, redirectGuest), true}, - {"PATCH", "/account/profile", mw.Auth(a, c.Users.ProfileUpdate, nil), true}, - {"PATCH", "/account/password", mw.Auth(a, c.Users.PasswordUpdate, nil), true}, + {"POST", "/verification-token", mw.Auth(a.DB, c.Users.CreateEmailVerificationToken, redirectGuest), true}, + {"GET", "/verify-email/{token}", mw.Auth(a.DB, c.Users.VerifyEmail, redirectGuest), true}, + {"PATCH", "/account/profile", mw.Auth(a.DB, c.Users.ProfileUpdate, nil), true}, + {"PATCH", "/account/password", mw.Auth(a.DB, c.Users.PasswordUpdate, nil), true}, {"GET", "/health", c.Health.Index, true}, } - if !a.Config.DisableRegistration { + if !a.DisableRegistration { ret = append(ret, Route{"GET", "/join", c.Users.New, true}) ret = append(ret, Route{"POST", "/join", c.Users.Create, true}) } @@ -76,28 +76,25 @@ func NewWebRoutes(a *app.App, c *Controllers) []Route { // NewAPIRoutes returns a new api routes func NewAPIRoutes(a *app.App, c *Controllers) []Route { - - proOnly := mw.AuthParams{ProOnly: true} - return []Route{ // v3 - {"GET", "/v3/sync/fragment", mw.Cors(mw.Auth(a, c.Sync.GetSyncFragment, &proOnly)), false}, - {"GET", "/v3/sync/state", mw.Cors(mw.Auth(a, c.Sync.GetSyncState, &proOnly)), false}, - {"POST", "/v3/signin", mw.Cors(c.Users.V3Login), true}, - {"POST", "/v3/signout", mw.Cors(c.Users.V3Logout), true}, - {"OPTIONS", "/v3/signout", mw.Cors(c.Users.logoutOptions), true}, - {"GET", "/v3/notes", mw.Cors(mw.Auth(a, c.Notes.V3Index, nil)), true}, + {"GET", "/v3/sync/fragment", mw.Auth(a.DB, c.Sync.GetSyncFragment, nil), false}, + {"GET", "/v3/sync/state", mw.Auth(a.DB, c.Sync.GetSyncState, nil), false}, + {"POST", "/v3/signin", c.Users.V3Login, true}, + {"POST", "/v3/signout", c.Users.V3Logout, true}, + {"OPTIONS", "/v3/signout", c.Users.logoutOptions, true}, + {"GET", "/v3/notes", mw.Auth(a.DB, c.Notes.V3Index, nil), true}, {"GET", "/v3/notes/{noteUUID}", c.Notes.V3Show, true}, - {"POST", "/v3/notes", mw.Cors(mw.Auth(a, c.Notes.V3Create, nil)), true}, - {"DELETE", "/v3/notes/{noteUUID}", mw.Cors(mw.Auth(a, c.Notes.V3Delete, nil)), true}, - {"PATCH", "/v3/notes/{noteUUID}", mw.Cors(mw.Auth(a, c.Notes.V3Update, nil)), true}, - {"OPTIONS", "/v3/notes", mw.Cors(c.Notes.IndexOptions), true}, - {"GET", "/v3/books", mw.Cors(mw.Auth(a, c.Books.V3Index, nil)), true}, - {"GET", "/v3/books/{bookUUID}", mw.Cors(mw.Auth(a, c.Books.V3Show, nil)), true}, - {"POST", "/v3/books", mw.Cors(mw.Auth(a, c.Books.V3Create, nil)), true}, - {"PATCH", "/v3/books/{bookUUID}", mw.Cors(mw.Auth(a, c.Books.V3Update, nil)), true}, - {"DELETE", "/v3/books/{bookUUID}", mw.Cors(mw.Auth(a, c.Books.V3Delete, nil)), true}, - {"OPTIONS", "/v3/books", mw.Cors(c.Books.IndexOptions), true}, + {"POST", "/v3/notes", mw.Auth(a.DB, c.Notes.V3Create, nil), true}, + {"DELETE", "/v3/notes/{noteUUID}", mw.Auth(a.DB, c.Notes.V3Delete, nil), true}, + {"PATCH", "/v3/notes/{noteUUID}", mw.Auth(a.DB, c.Notes.V3Update, nil), true}, + {"OPTIONS", "/v3/notes", c.Notes.IndexOptions, true}, + {"GET", "/v3/books", mw.Auth(a.DB, c.Books.V3Index, nil), true}, + {"GET", "/v3/books/{bookUUID}", mw.Auth(a.DB, c.Books.V3Show, nil), true}, + {"POST", "/v3/books", mw.Auth(a.DB, c.Books.V3Create, nil), true}, + {"PATCH", "/v3/books/{bookUUID}", mw.Auth(a.DB, c.Books.V3Update, nil), true}, + {"DELETE", "/v3/books/{bookUUID}", mw.Auth(a.DB, c.Books.V3Delete, nil), true}, + {"OPTIONS", "/v3/books", c.Books.IndexOptions, true}, } } diff --git a/pkg/server/controllers/routes_test.go b/pkg/server/controllers/routes_test.go index 39b151cf..d3084aa3 100644 --- a/pkg/server/controllers/routes_test.go +++ b/pkg/server/controllers/routes_test.go @@ -25,7 +25,6 @@ import ( "github.com/dnote/dnote/pkg/assert" "github.com/dnote/dnote/pkg/clock" "github.com/dnote/dnote/pkg/server/app" - "github.com/dnote/dnote/pkg/server/config" "github.com/dnote/dnote/pkg/server/testutils" ) @@ -56,10 +55,11 @@ func TestNotSupportedVersions(t *testing.T) { } // setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + db := testutils.InitMemoryDB(t) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() for _, tc := range testCases { diff --git a/pkg/server/controllers/testutils.go b/pkg/server/controllers/testutils.go index 807f92d8..03a097c8 100644 --- a/pkg/server/controllers/testutils.go +++ b/pkg/server/controllers/testutils.go @@ -27,9 +27,9 @@ import ( ) // MustNewServer is a test utility function to initialize a new server -// with the given app paratmers -func MustNewServer(t *testing.T, appParams *app.App) *httptest.Server { - server, err := NewServer(appParams) +// with the given app +func MustNewServer(t *testing.T, a *app.App) *httptest.Server { + server, err := NewServer(a) if err != nil { t.Fatal(errors.Wrap(err, "initializing router")) } @@ -37,16 +37,14 @@ func MustNewServer(t *testing.T, appParams *app.App) *httptest.Server { return server } -func NewServer(appParams *app.App) (*httptest.Server, error) { - a := app.NewTest(appParams) - - ctl := New(&a) +func NewServer(a *app.App) (*httptest.Server, error) { + ctl := New(a) rc := RouteConfig{ - WebRoutes: NewWebRoutes(&a, ctl), - APIRoutes: NewAPIRoutes(&a, ctl), + WebRoutes: NewWebRoutes(a, ctl), + APIRoutes: NewAPIRoutes(a, ctl), Controllers: ctl, } - r, err := NewRouter(&a, rc) + r, err := NewRouter(a, rc) if err != nil { return nil, errors.Wrap(err, "initializing router") } diff --git a/pkg/server/controllers/users_test.go b/pkg/server/controllers/users_test.go index 4546a18a..e4c9ff2d 100644 --- a/pkg/server/controllers/users_test.go +++ b/pkg/server/controllers/users_test.go @@ -30,18 +30,18 @@ import ( "github.com/dnote/dnote/pkg/assert" "github.com/dnote/dnote/pkg/clock" "github.com/dnote/dnote/pkg/server/app" - "github.com/dnote/dnote/pkg/server/config" "github.com/dnote/dnote/pkg/server/database" "github.com/dnote/dnote/pkg/server/testutils" "github.com/pkg/errors" "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" ) -func assertResponseSessionCookie(t *testing.T, res *http.Response) { +func assertResponseSessionCookie(t *testing.T, db *gorm.DB, res *http.Response) { var sessionCount int64 var session database.Session - testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session") - testutils.MustExec(t, testutils.DB.First(&session), "getting session") + testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session") + testutils.MustExec(t, db.First(&session), "getting session") c := testutils.GetCookieByName(res.Cookies(), "id") assert.Equal(t, c.Value, session.Key, "session key mismatch") @@ -55,53 +55,35 @@ func TestJoin(t *testing.T) { email string password string passwordConfirmation string - onPremises bool - expectedPro bool }{ { email: "alice@example.com", password: "pass1234", passwordConfirmation: "pass1234", - onPremises: false, - expectedPro: false, }, { email: "bob@example.com", password: "Y9EwmjH@Jq6y5a64MSACUoM4w7SAhzvY", passwordConfirmation: "Y9EwmjH@Jq6y5a64MSACUoM4w7SAhzvY", - onPremises: false, - expectedPro: false, }, { email: "chuck@example.com", password: "e*H@kJi^vXbWEcD9T5^Am!Y@7#Po2@PC", passwordConfirmation: "e*H@kJi^vXbWEcD9T5^Am!Y@7#Po2@PC", - onPremises: false, - expectedPro: false, - }, - // on premise - { - email: "dan@example.com", - password: "e*H@kJi^vXbWEcD9T5^Am!Y@7#Po2@PC", - passwordConfirmation: "e*H@kJi^vXbWEcD9T5^Am!Y@7#Po2@PC", - onPremises: true, - expectedPro: true, }, } for _, tc := range testCases { t.Run(fmt.Sprintf("register %s %s", tc.email, tc.password), func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup emailBackend := testutils.MockEmailbackendImplementation{} - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - EmailBackend: &emailBackend, - Config: config.Config{ - OnPremises: tc.onPremises, - }, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.EmailBackend = &emailBackend + a.DB = db + server := MustNewServer(t, &a) defer server.Close() dat := url.Values{} @@ -117,15 +99,14 @@ func TestJoin(t *testing.T) { assert.StatusCodeEquals(t, res, http.StatusFound, "") var account database.Account - testutils.MustExec(t, testutils.DB.Where("email = ?", tc.email).First(&account), "finding account") + testutils.MustExec(t, db.Where("email = ?", tc.email).First(&account), "finding account") assert.Equal(t, account.Email.String, tc.email, "Email mismatch") assert.NotEqual(t, account.UserID, 0, "UserID mismatch") passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte(tc.password)) assert.Equal(t, passwordErr, nil, "Password mismatch") var user database.User - testutils.MustExec(t, testutils.DB.Where("id = ?", account.UserID).First(&user), "finding user") - assert.Equal(t, user.Cloud, tc.expectedPro, "Cloud mismatch") + testutils.MustExec(t, db.Where("id = ?", account.UserID).First(&user), "finding user") assert.Equal(t, user.MaxUSN, 0, "MaxUSN mismatch") // welcome email @@ -133,20 +114,20 @@ func TestJoin(t *testing.T) { assert.DeepEqual(t, emailBackend.Emails[0].To, []string{tc.email}, "email to mismatch") // after register, should sign in user - assertResponseSessionCookie(t, res) + assertResponseSessionCookie(t, db, res) }) } } func TestJoinError(t *testing.T) { t.Run("missing email", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() dat := url.Values{} @@ -160,21 +141,21 @@ func TestJoinError(t *testing.T) { assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch") var accountCount, userCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account") - testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user") + testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account") + testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user") assert.Equal(t, accountCount, int64(0), "accountCount mismatch") assert.Equal(t, userCount, int64(0), "userCount mismatch") }) t.Run("missing password", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() dat := url.Values{} @@ -188,21 +169,21 @@ func TestJoinError(t *testing.T) { assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch") var accountCount, userCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account") - testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user") + testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account") + testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user") assert.Equal(t, accountCount, int64(0), "accountCount mismatch") assert.Equal(t, userCount, int64(0), "userCount mismatch") }) t.Run("password confirmation mismatch", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() dat := url.Values{} @@ -218,8 +199,8 @@ func TestJoinError(t *testing.T) { assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch") var accountCount, userCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account") - testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user") + testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account") + testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user") assert.Equal(t, accountCount, int64(0), "accountCount mismatch") assert.Equal(t, userCount, int64(0), "userCount mismatch") @@ -227,17 +208,17 @@ func TestJoinError(t *testing.T) { } func TestJoinDuplicateEmail(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData() - testutils.SetupAccountData(u, "alice@example.com", "somepassword") + u := testutils.SetupUserData(db) + testutils.SetupAccountData(db, u, "alice@example.com", "somepassword") dat := url.Values{} dat.Set("email", "alice@example.com") @@ -252,12 +233,12 @@ func TestJoinDuplicateEmail(t *testing.T) { assert.StatusCodeEquals(t, res, http.StatusBadRequest, "status code mismatch") var accountCount, userCount, verificationTokenCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account") - testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user") - testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&verificationTokenCount), "counting verification token") + testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account") + testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user") + testutils.MustExec(t, db.Model(&database.Token{}).Count(&verificationTokenCount), "counting verification token") var user database.User - testutils.MustExec(t, testutils.DB.Where("id = ?", u.ID).First(&user), "finding user") + testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user") assert.Equal(t, accountCount, int64(1), "account count mismatch") assert.Equal(t, userCount, int64(1), "user count mismatch") @@ -266,15 +247,14 @@ func TestJoinDuplicateEmail(t *testing.T) { } func TestJoinDisabled(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{ - DisableRegistration: true, - }, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + a.DisableRegistration = true + server := MustNewServer(t, &a) defer server.Close() dat := url.Values{} @@ -289,8 +269,8 @@ func TestJoinDisabled(t *testing.T) { assert.StatusCodeEquals(t, res, http.StatusNotFound, "status code mismatch") var accountCount, userCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.Account{}).Count(&accountCount), "counting account") - testutils.MustExec(t, testutils.DB.Model(&database.User{}).Count(&userCount), "counting user") + testutils.MustExec(t, db.Model(&database.Account{}).Count(&accountCount), "counting account") + testutils.MustExec(t, db.Model(&database.User{}).Count(&userCount), "counting user") assert.Equal(t, accountCount, int64(0), "account count mismatch") assert.Equal(t, userCount, int64(0), "user count mismatch") @@ -298,16 +278,16 @@ func TestJoinDisabled(t *testing.T) { func TestLogin(t *testing.T) { testutils.RunForWebAndAPI(t, "success", func(t *testing.T, target testutils.EndpointType) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) - u := testutils.SetupUserData() - testutils.SetupAccountData(u, "alice@example.com", "pass1234") + u := testutils.SetupUserData(db) + testutils.SetupAccountData(db, u, "alice@example.com", "pass1234") defer server.Close() // Execute @@ -332,11 +312,11 @@ func TestLogin(t *testing.T) { } var user database.User - testutils.MustExec(t, testutils.DB.Model(&database.User{}).First(&user), "finding user") + testutils.MustExec(t, db.Model(&database.User{}).First(&user), "finding user") assert.NotEqual(t, user.LastLoginAt, nil, "LastLoginAt mismatch") if target == testutils.EndpointWeb { - assertResponseSessionCookie(t, res) + assertResponseSessionCookie(t, db, res) } else { // after register, should sign in user var got SessionResponse @@ -346,28 +326,28 @@ func TestLogin(t *testing.T) { var sessionCount int64 var session database.Session - testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session") - testutils.MustExec(t, testutils.DB.First(&session), "getting session") + testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session") + testutils.MustExec(t, db.First(&session), "getting session") assert.Equal(t, sessionCount, int64(1), "sessionCount mismatch") assert.Equal(t, got.Key, session.Key, "session Key mismatch") assert.Equal(t, got.ExpiresAt, session.ExpiresAt.Unix(), "session ExpiresAt mismatch") - assertResponseSessionCookie(t, res) + assertResponseSessionCookie(t, db, res) } }) testutils.RunForWebAndAPI(t, "wrong password", func(t *testing.T, target testutils.EndpointType) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) - u := testutils.SetupUserData() - testutils.SetupAccountData(u, "alice@example.com", "pass1234") + u := testutils.SetupUserData(db) + testutils.SetupAccountData(db, u, "alice@example.com", "pass1234") defer server.Close() var req *http.Request @@ -388,26 +368,26 @@ func TestLogin(t *testing.T) { assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "") var user database.User - testutils.MustExec(t, testutils.DB.Model(&database.User{}).First(&user), "finding user") + testutils.MustExec(t, db.Model(&database.User{}).First(&user), "finding user") assert.Equal(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch") var sessionCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session") + testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session") assert.Equal(t, sessionCount, int64(0), "sessionCount mismatch") }) testutils.RunForWebAndAPI(t, "wrong email", func(t *testing.T, target testutils.EndpointType) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData() - testutils.SetupAccountData(u, "alice@example.com", "pass1234") + u := testutils.SetupUserData(db) + testutils.SetupAccountData(db, u, "alice@example.com", "pass1234") var req *http.Request if target == testutils.EndpointWeb { @@ -427,22 +407,22 @@ func TestLogin(t *testing.T) { assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "") var user database.User - testutils.MustExec(t, testutils.DB.Model(&database.User{}).First(&user), "finding user") + testutils.MustExec(t, db.Model(&database.User{}).First(&user), "finding user") assert.DeepEqual(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch") var sessionCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session") + testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session") assert.Equal(t, sessionCount, int64(0), "sessionCount mismatch") }) testutils.RunForWebAndAPI(t, "nonexistent email", func(t *testing.T, target testutils.EndpointType) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() var req *http.Request @@ -463,22 +443,22 @@ func TestLogin(t *testing.T) { assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "") var sessionCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session") + testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session") assert.Equal(t, sessionCount, int64(0), "sessionCount mismatch") }) } func TestLogout(t *testing.T) { - setupLogoutTest := func(t *testing.T) (*httptest.Server, *database.Session, *database.Session) { + setupLogoutTest := func(t *testing.T, db *gorm.DB) (*httptest.Server, *database.Session, *database.Session) { // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) - aliceUser := testutils.SetupUserData() - testutils.SetupAccountData(aliceUser, "alice@example.com", "pass1234") - anotherUser := testutils.SetupUserData() + aliceUser := testutils.SetupUserData(db) + testutils.SetupAccountData(db, aliceUser, "alice@example.com", "pass1234") + anotherUser := testutils.SetupUserData(db) session1ExpiresAt := time.Now().Add(time.Hour * 24) session1 := database.Session{ @@ -486,21 +466,21 @@ func TestLogout(t *testing.T) { UserID: aliceUser.ID, ExpiresAt: session1ExpiresAt, } - testutils.MustExec(t, testutils.DB.Save(&session1), "preparing session1") + testutils.MustExec(t, db.Save(&session1), "preparing session1") session2 := database.Session{ Key: "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=", UserID: anotherUser.ID, ExpiresAt: time.Now().Add(time.Hour * 24), } - testutils.MustExec(t, testutils.DB.Save(&session2), "preparing session2") + testutils.MustExec(t, db.Save(&session2), "preparing session2") return server, &session1, &session2 } testutils.RunForWebAndAPI(t, "authenticated", func(t *testing.T, target testutils.EndpointType) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) - server, session1, _ := setupLogoutTest(t) + server, session1, _ := setupLogoutTest(t, db) defer server.Close() // Execute @@ -525,8 +505,8 @@ func TestLogout(t *testing.T) { var sessionCount int64 var s2 database.Session - testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session") - testutils.MustExec(t, testutils.DB.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&s2), "getting s2") + testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session") + testutils.MustExec(t, db.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&s2), "getting s2") assert.Equal(t, sessionCount, int64(1), "sessionCount mismatch") @@ -542,9 +522,9 @@ func TestLogout(t *testing.T) { }) testutils.RunForWebAndAPI(t, "unauthenticated", func(t *testing.T, target testutils.EndpointType) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) - server, _, _ := setupLogoutTest(t) + server, _, _ := setupLogoutTest(t, db) defer server.Close() // Execute @@ -567,9 +547,9 @@ func TestLogout(t *testing.T) { var sessionCount int64 var postSession1, postSession2 database.Session - testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Count(&sessionCount), "counting session") - testutils.MustExec(t, testutils.DB.Where("key = ?", "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=").First(&postSession1), "getting postSession1") - testutils.MustExec(t, testutils.DB.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&postSession2), "getting postSession2") + testutils.MustExec(t, db.Model(&database.Session{}).Count(&sessionCount), "counting session") + testutils.MustExec(t, db.Where("key = ?", "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=").First(&postSession1), "getting postSession1") + testutils.MustExec(t, db.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&postSession2), "getting postSession2") // two existing sessions should remain assert.Equal(t, sessionCount, int64(2), "sessionCount mismatch") @@ -581,46 +561,46 @@ func TestLogout(t *testing.T) { func TestResetPassword(t *testing.T) { t.Run("success", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData() - a := testutils.SetupAccountData(u, "alice@example.com", "oldpassword") + u := testutils.SetupUserData(db) + acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword") tok := database.Token{ UserID: u.ID, Value: "MivFxYiSMMA4An9dP24DNQ==", Type: database.TokenTypeResetPassword, } - testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token") + testutils.MustExec(t, db.Save(&tok), "preparing token") otherTok := database.Token{ UserID: u.ID, Value: "somerandomvalue", Type: database.TokenTypeEmailVerification, } - testutils.MustExec(t, testutils.DB.Save(&otherTok), "preparing another token") + testutils.MustExec(t, db.Save(&otherTok), "preparing another token") s1 := database.Session{ Key: "some-session-key-1", UserID: u.ID, ExpiresAt: time.Now().Add(time.Hour * 10 * 24), } - testutils.MustExec(t, testutils.DB.Save(&s1), "preparing user session 1") + testutils.MustExec(t, db.Save(&s1), "preparing user session 1") s2 := &database.Session{ Key: "some-session-key-2", UserID: u.ID, ExpiresAt: time.Now().Add(time.Hour * 10 * 24), } - testutils.MustExec(t, testutils.DB.Save(&s2), "preparing user session 2") + testutils.MustExec(t, db.Save(&s2), "preparing user session 2") - anotherUser := testutils.SetupUserData() - testutils.MustExec(t, testutils.DB.Save(&database.Session{ + anotherUser := testutils.SetupUserData(db) + testutils.MustExec(t, db.Save(&database.Session{ Key: "some-session-key-3", UserID: anotherUser.ID, ExpiresAt: time.Now().Add(time.Hour * 10 * 24), @@ -640,9 +620,9 @@ func TestResetPassword(t *testing.T) { var resetToken, verificationToken database.Token var account database.Account - testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token") - testutils.MustExec(t, testutils.DB.Where("value = ?", "somerandomvalue").First(&verificationToken), "finding reset token") - testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "finding account") + testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token") + testutils.MustExec(t, db.Where("value = ?", "somerandomvalue").First(&verificationToken), "finding reset token") + testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "finding account") assert.NotEqual(t, resetToken.UsedAt, nil, "reset_token UsedAt mismatch") passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte("newpassword")) @@ -650,38 +630,38 @@ func TestResetPassword(t *testing.T) { assert.Equal(t, verificationToken.UsedAt, (*time.Time)(nil), "verificationToken UsedAt mismatch") var s1Count, s2Count int64 - testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Where("id = ?", s1.ID).Count(&s1Count), "counting s1") - testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Where("id = ?", s2.ID).Count(&s2Count), "counting s2") + testutils.MustExec(t, db.Model(&database.Session{}).Where("id = ?", s1.ID).Count(&s1Count), "counting s1") + testutils.MustExec(t, db.Model(&database.Session{}).Where("id = ?", s2.ID).Count(&s2Count), "counting s2") assert.Equal(t, s1Count, int64(0), "s1 should have been deleted") assert.Equal(t, s2Count, int64(0), "s2 should have been deleted") var userSessionCount, anotherUserSessionCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Where("user_id = ?", u.ID).Count(&userSessionCount), "counting user session") - testutils.MustExec(t, testutils.DB.Model(&database.Session{}).Where("user_id = ?", anotherUser.ID).Count(&anotherUserSessionCount), "counting anotherUser session") + testutils.MustExec(t, db.Model(&database.Session{}).Where("user_id = ?", u.ID).Count(&userSessionCount), "counting user session") + testutils.MustExec(t, db.Model(&database.Session{}).Where("user_id = ?", anotherUser.ID).Count(&anotherUserSessionCount), "counting anotherUser session") assert.Equal(t, userSessionCount, int64(0), "should have deleted a user session") assert.Equal(t, anotherUserSessionCount, int64(1), "anotherUser session count mismatch") }) t.Run("nonexistent token", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData() - a := testutils.SetupAccountData(u, "alice@example.com", "somepassword") + u := testutils.SetupUserData(db) + acc := testutils.SetupAccountData(db, u, "alice@example.com", "somepassword") tok := database.Token{ UserID: u.ID, Value: "MivFxYiSMMA4An9dP24DNQ==", Type: database.TokenTypeResetPassword, } - testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token") + testutils.MustExec(t, db.Save(&tok), "preparing token") dat := url.Values{} dat.Set("token", "-ApMnyvpg59uOU5b-Kf5uQ==") @@ -697,33 +677,33 @@ func TestResetPassword(t *testing.T) { var resetToken database.Token var account database.Account - testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token") - testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "finding account") + testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "finding reset token") + testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "finding account") - assert.Equal(t, a.Password, account.Password, "password should not have been updated") - assert.Equal(t, a.Password, account.Password, "password should not have been updated") + assert.Equal(t, acc.Password, account.Password, "password should not have been updated") + assert.Equal(t, acc.Password, account.Password, "password should not have been updated") assert.Equal(t, resetToken.UsedAt, (*time.Time)(nil), "used_at should be nil") }) t.Run("expired token", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData() - a := testutils.SetupAccountData(u, "alice@example.com", "somepassword") + u := testutils.SetupUserData(db) + acc := testutils.SetupAccountData(db, u, "alice@example.com", "somepassword") tok := database.Token{ UserID: u.ID, Value: "MivFxYiSMMA4An9dP24DNQ==", Type: database.TokenTypeResetPassword, } - testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token") - testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at") + testutils.MustExec(t, db.Save(&tok), "preparing token") + testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at") dat := url.Values{} dat.Set("token", "MivFxYiSMMA4An9dP24DNQ==") @@ -739,24 +719,24 @@ func TestResetPassword(t *testing.T) { var resetToken database.Token var account database.Account - testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token") - testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "failed to find account") - assert.Equal(t, a.Password, account.Password, "password should not have been updated") + testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token") + testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "failed to find account") + assert.Equal(t, acc.Password, account.Password, "password should not have been updated") assert.Equal(t, resetToken.UsedAt, (*time.Time)(nil), "used_at should be nil") }) t.Run("used token", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData() - a := testutils.SetupAccountData(u, "alice@example.com", "somepassword") + u := testutils.SetupUserData(db) + acc := testutils.SetupAccountData(db, u, "alice@example.com", "somepassword") usedAt := time.Now().Add(time.Hour * -11).UTC() tok := database.Token{ @@ -765,8 +745,8 @@ func TestResetPassword(t *testing.T) { Type: database.TokenTypeResetPassword, UsedAt: &usedAt, } - testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token") - testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at") + testutils.MustExec(t, db.Save(&tok), "preparing token") + testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at") dat := url.Values{} dat.Set("token", "MivFxYiSMMA4An9dP24DNQ==") @@ -782,9 +762,9 @@ func TestResetPassword(t *testing.T) { var resetToken database.Token var account database.Account - testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token") - testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "failed to find account") - assert.Equal(t, a.Password, account.Password, "password should not have been updated") + testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token") + testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "failed to find account") + assert.Equal(t, acc.Password, account.Password, "password should not have been updated") resetTokenUsedAtUTC := resetToken.UsedAt.UTC() if resetTokenUsedAtUTC.Year() != usedAt.Year() || @@ -798,24 +778,24 @@ func TestResetPassword(t *testing.T) { }) t.Run("using wrong type token: email_verification", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData() - a := testutils.SetupAccountData(u, "alice@example.com", "somepassword") + u := testutils.SetupUserData(db) + acc := testutils.SetupAccountData(db, u, "alice@example.com", "somepassword") tok := database.Token{ UserID: u.ID, Value: "MivFxYiSMMA4An9dP24DNQ==", Type: database.TokenTypeEmailVerification, } - testutils.MustExec(t, testutils.DB.Save(&tok), "Failed to prepare reset_token") - testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at") + testutils.MustExec(t, db.Save(&tok), "Failed to prepare reset_token") + testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-11)), "Failed to prepare reset_token created_at") dat := url.Values{} dat.Set("token", "MivFxYiSMMA4An9dP24DNQ==") @@ -831,27 +811,27 @@ func TestResetPassword(t *testing.T) { var resetToken database.Token var account database.Account - testutils.MustExec(t, testutils.DB.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token") - testutils.MustExec(t, testutils.DB.Where("id = ?", a.ID).First(&account), "failed to find account") + testutils.MustExec(t, db.Where("value = ?", "MivFxYiSMMA4An9dP24DNQ==").First(&resetToken), "failed to find reset_token") + testutils.MustExec(t, db.Where("id = ?", acc.ID).First(&account), "failed to find account") - assert.Equal(t, a.Password, account.Password, "password should not have been updated") + assert.Equal(t, acc.Password, account.Password, "password should not have been updated") assert.Equal(t, resetToken.UsedAt, (*time.Time)(nil), "used_at should be nil") }) } func TestCreateResetToken(t *testing.T) { t.Run("success", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData() - testutils.SetupAccountData(u, "alice@example.com", "somepassword") + u := testutils.SetupUserData(db) + testutils.SetupAccountData(db, u, "alice@example.com", "somepassword") // Execute dat := url.Values{} @@ -864,10 +844,10 @@ func TestCreateResetToken(t *testing.T) { assert.StatusCodeEquals(t, res, http.StatusFound, "Status code mismtach") var tokenCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting tokens") + testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting tokens") var resetToken database.Token - testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", u.ID, database.TokenTypeResetPassword).First(&resetToken), "finding reset token") + testutils.MustExec(t, db.Where("user_id = ? AND type = ?", u.ID, database.TokenTypeResetPassword).First(&resetToken), "finding reset token") assert.Equal(t, tokenCount, int64(1), "reset_token count mismatch") assert.NotEqual(t, resetToken.Value, nil, "reset_token value mismatch") @@ -875,17 +855,17 @@ func TestCreateResetToken(t *testing.T) { }) t.Run("nonexistent email", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData() - testutils.SetupAccountData(u, "alice@example.com", "somepassword") + u := testutils.SetupUserData(db) + testutils.SetupAccountData(db, u, "alice@example.com", "somepassword") // Execute dat := url.Values{} @@ -898,24 +878,24 @@ func TestCreateResetToken(t *testing.T) { assert.StatusCodeEquals(t, res, http.StatusOK, "Status code mismtach") var tokenCount int64 - testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting tokens") + testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting tokens") assert.Equal(t, tokenCount, int64(0), "reset_token count mismatch") }) } func TestUpdatePassword(t *testing.T) { t.Run("success", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@example.com", "oldpassword") + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@example.com", "oldpassword") // Execute dat := url.Values{} @@ -924,29 +904,29 @@ func TestUpdatePassword(t *testing.T) { dat.Set("new_password_confirmation", "newpassword") req := testutils.MakeFormReq(server.URL, "PATCH", "/account/password", dat) - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusFound, "Status code mismsatch") var account database.Account - testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account") + testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account") passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte("newpassword")) assert.Equal(t, passwordErr, nil, "Password mismatch") }) t.Run("old password mismatch", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData() - a := testutils.SetupAccountData(u, "alice@example.com", "oldpassword") + u := testutils.SetupUserData(db) + acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword") // Execute dat := url.Values{} @@ -955,28 +935,28 @@ func TestUpdatePassword(t *testing.T) { dat.Set("new_password_confirmation", "newpassword") req := testutils.MakeFormReq(server.URL, "PATCH", "/account/password", dat) - res := testutils.HTTPAuthDo(t, req, u) + res := testutils.HTTPAuthDo(t, db, req, u) // Test assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "Status code mismsatch") var account database.Account - testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account") - assert.Equal(t, a.Password.String, account.Password.String, "password should not have been updated") + testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account") + assert.Equal(t, acc.Password.String, account.Password.String, "password should not have been updated") }) t.Run("password too short", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData() - a := testutils.SetupAccountData(u, "alice@example.com", "oldpassword") + u := testutils.SetupUserData(db) + acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword") // Execute dat := url.Values{} @@ -985,28 +965,28 @@ func TestUpdatePassword(t *testing.T) { dat.Set("new_password_confirmation", "a") req := testutils.MakeFormReq(server.URL, "PATCH", "/account/password", dat) - res := testutils.HTTPAuthDo(t, req, u) + res := testutils.HTTPAuthDo(t, db, req, u) // Test assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismsatch") var account database.Account - testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account") - assert.Equal(t, a.Password.String, account.Password.String, "password should not have been updated") + testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account") + assert.Equal(t, acc.Password.String, account.Password.String, "password should not have been updated") }) t.Run("password confirmation mismatch", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData() - a := testutils.SetupAccountData(u, "alice@example.com", "oldpassword") + u := testutils.SetupUserData(db) + acc := testutils.SetupAccountData(db, u, "alice@example.com", "oldpassword") // Execute dat := url.Values{} @@ -1015,32 +995,32 @@ func TestUpdatePassword(t *testing.T) { dat.Set("new_password_confirmation", "newpassword2") req := testutils.MakeFormReq(server.URL, "PATCH", "/account/password", dat) - res := testutils.HTTPAuthDo(t, req, u) + res := testutils.HTTPAuthDo(t, db, req, u) // Test assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status code mismsatch") var account database.Account - testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account") - assert.Equal(t, a.Password.String, account.Password.String, "password should not have been updated") + testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account") + assert.Equal(t, acc.Password.String, account.Password.String, "password should not have been updated") }) } func TestUpdateEmail(t *testing.T) { t.Run("success", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData() - a := testutils.SetupAccountData(u, "alice@example.com", "pass1234") - a.EmailVerified = true - testutils.MustExec(t, testutils.DB.Save(&a), "updating email_verified") + u := testutils.SetupUserData(db) + acc := testutils.SetupAccountData(db, u, "alice@example.com", "pass1234") + acc.EmailVerified = true + testutils.MustExec(t, db.Save(&acc), "updating email_verified") // Execute dat := url.Values{} @@ -1048,34 +1028,34 @@ func TestUpdateEmail(t *testing.T) { dat.Set("password", "pass1234") req := testutils.MakeFormReq(server.URL, "PATCH", "/account/profile", dat) - res := testutils.HTTPAuthDo(t, req, u) + res := testutils.HTTPAuthDo(t, db, req, u) // Test assert.StatusCodeEquals(t, res, http.StatusFound, "Status code mismatch") var user database.User var account database.Account - testutils.MustExec(t, testutils.DB.Where("id = ?", u.ID).First(&user), "finding user") - testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account") + testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user") + testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account") assert.Equal(t, account.Email.String, "alice-new@example.com", "email mismatch") assert.Equal(t, account.EmailVerified, false, "EmailVerified mismatch") }) t.Run("password mismatch", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - u := testutils.SetupUserData() - a := testutils.SetupAccountData(u, "alice@example.com", "pass1234") - a.EmailVerified = true - testutils.MustExec(t, testutils.DB.Save(&a), "updating email_verified") + u := testutils.SetupUserData(db) + acc := testutils.SetupAccountData(db, u, "alice@example.com", "pass1234") + acc.EmailVerified = true + testutils.MustExec(t, db.Save(&acc), "updating email_verified") // Execute dat := url.Values{} @@ -1083,15 +1063,15 @@ func TestUpdateEmail(t *testing.T) { dat.Set("password", "wrongpassword") req := testutils.MakeFormReq(server.URL, "PATCH", "/account/profile", dat) - res := testutils.HTTPAuthDo(t, req, u) + res := testutils.HTTPAuthDo(t, db, req, u) // Test assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "Status code mismsatch") var user database.User var account database.Account - testutils.MustExec(t, testutils.DB.Where("id = ?", u.ID).First(&user), "finding user") - testutils.MustExec(t, testutils.DB.Where("user_id = ?", u.ID).First(&account), "finding account") + testutils.MustExec(t, db.Where("id = ?", u.ID).First(&user), "finding user") + testutils.MustExec(t, db.Where("user_id = ?", u.ID).First(&account), "finding account") assert.Equal(t, account.Email.String, "alice@example.com", "email mismatch") assert.Equal(t, account.EmailVerified, true, "EmailVerified mismatch") @@ -1100,27 +1080,27 @@ func TestUpdateEmail(t *testing.T) { func TestVerifyEmail(t *testing.T) { t.Run("success", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@example.com", "pass1234") + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@example.com", "pass1234") tok := database.Token{ UserID: user.ID, Type: database.TokenTypeEmailVerification, Value: "someTokenValue", } - testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token") + testutils.MustExec(t, db.Save(&tok), "preparing token") // Execute req := testutils.MakeReq(server.URL, "GET", fmt.Sprintf("/verify-email/%s", "someTokenValue"), "") - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusFound, "Status code mismatch") @@ -1128,9 +1108,9 @@ func TestVerifyEmail(t *testing.T) { var account database.Account var token database.Token var tokenCount int64 - testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account") - testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token") - testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token") + testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account") + testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token") + testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token") assert.Equal(t, account.EmailVerified, true, "email_verified mismatch") assert.NotEqual(t, token.Value, "", "token value should not have been updated") @@ -1139,17 +1119,17 @@ func TestVerifyEmail(t *testing.T) { }) t.Run("used token", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@example.com", "pass1234") + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@example.com", "pass1234") usedAt := time.Now().Add(time.Hour * -11).UTC() tok := database.Token{ @@ -1158,11 +1138,11 @@ func TestVerifyEmail(t *testing.T) { Value: "someTokenValue", UsedAt: &usedAt, } - testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token") + testutils.MustExec(t, db.Save(&tok), "preparing token") // Execute req := testutils.MakeReq(server.URL, "GET", fmt.Sprintf("/verify-email/%s", "someTokenValue"), "") - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusBadRequest, "") @@ -1170,9 +1150,9 @@ func TestVerifyEmail(t *testing.T) { var account database.Account var token database.Token var tokenCount int64 - testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account") - testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token") - testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token") + testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account") + testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token") + testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token") assert.Equal(t, account.EmailVerified, false, "email_verified mismatch") assert.NotEqual(t, token.UsedAt, nil, "token used_at mismatch") @@ -1181,29 +1161,29 @@ func TestVerifyEmail(t *testing.T) { }) t.Run("expired token", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@example.com", "pass1234") + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@example.com", "pass1234") tok := database.Token{ UserID: user.ID, Type: database.TokenTypeEmailVerification, Value: "someTokenValue", } - testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token") - testutils.MustExec(t, testutils.DB.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-31)), "Failed to prepare token created_at") + testutils.MustExec(t, db.Save(&tok), "preparing token") + testutils.MustExec(t, db.Model(&tok).Update("created_at", time.Now().Add(time.Minute*-31)), "Failed to prepare token created_at") // Execute req := testutils.MakeReq(server.URL, "GET", fmt.Sprintf("/verify-email/%s", "someTokenValue"), "") - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusGone, "") @@ -1211,9 +1191,9 @@ func TestVerifyEmail(t *testing.T) { var account database.Account var token database.Token var tokenCount int64 - testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account") - testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token") - testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token") + testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account") + testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token") + testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token") assert.Equal(t, account.EmailVerified, false, "email_verified mismatch") assert.Equal(t, tokenCount, int64(1), "token count mismatch") @@ -1221,30 +1201,30 @@ func TestVerifyEmail(t *testing.T) { }) t.Run("already verified", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - a := testutils.SetupAccountData(user, "alice@example.com", "oldpass1234") - a.EmailVerified = true - testutils.MustExec(t, testutils.DB.Save(&a), "preparing account") + user := testutils.SetupUserData(db) + acc := testutils.SetupAccountData(db, user, "alice@example.com", "oldpass1234") + acc.EmailVerified = true + testutils.MustExec(t, db.Save(&acc), "preparing account") tok := database.Token{ UserID: user.ID, Type: database.TokenTypeEmailVerification, Value: "someTokenValue", } - testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token") + testutils.MustExec(t, db.Save(&tok), "preparing token") // Execute req := testutils.MakeReq(server.URL, "GET", fmt.Sprintf("/verify-email/%s", "someTokenValue"), "") - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusConflict, "") @@ -1252,9 +1232,9 @@ func TestVerifyEmail(t *testing.T) { var account database.Account var token database.Token var tokenCount int64 - testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account") - testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token") - testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token") + testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account") + testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token") + testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token") assert.Equal(t, account.EmailVerified, true, "email_verified mismatch") assert.Equal(t, tokenCount, int64(1), "token count mismatch") @@ -1264,23 +1244,23 @@ func TestVerifyEmail(t *testing.T) { func TestCreateVerificationToken(t *testing.T) { t.Run("success", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup emailBackend := testutils.MockEmailbackendImplementation{} - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - EmailBackend: &emailBackend, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + a.EmailBackend = &emailBackend + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@example.com", "pass1234") + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@example.com", "pass1234") // Execute req := testutils.MakeReq(server.URL, "POST", "/verification-token", "") - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusFound, "status code mismatch") @@ -1288,9 +1268,9 @@ func TestCreateVerificationToken(t *testing.T) { var account database.Account var token database.Token var tokenCount int64 - testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account") - testutils.MustExec(t, testutils.DB.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token") - testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token") + testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account") + testutils.MustExec(t, db.Where("user_id = ? AND type = ?", user.ID, database.TokenTypeEmailVerification).First(&token), "finding token") + testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token") assert.Equal(t, account.EmailVerified, false, "email_verified should not have been updated") assert.NotEqual(t, token.Value, "", "token Value mismatch") @@ -1300,30 +1280,30 @@ func TestCreateVerificationToken(t *testing.T) { }) t.Run("already verified", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Setup - server := MustNewServer(t, &app.App{ - Clock: clock.NewMock(), - Config: config.Config{}, - }) + a := app.NewTest() + a.Clock = clock.NewMock() + a.DB = db + server := MustNewServer(t, &a) defer server.Close() - user := testutils.SetupUserData() - a := testutils.SetupAccountData(user, "alice@example.com", "pass1234") - a.EmailVerified = true - testutils.MustExec(t, testutils.DB.Save(&a), "preparing account") + user := testutils.SetupUserData(db) + acc := testutils.SetupAccountData(db, user, "alice@example.com", "pass1234") + acc.EmailVerified = true + testutils.MustExec(t, db.Save(&acc), "preparing account") // Execute req := testutils.MakeReq(server.URL, "POST", "/verification-token", "") - res := testutils.HTTPAuthDo(t, req, user) + res := testutils.HTTPAuthDo(t, db, req, user) // Test assert.StatusCodeEquals(t, res, http.StatusConflict, "Status code mismatch") var account database.Account var tokenCount int64 - testutils.MustExec(t, testutils.DB.Where("user_id = ?", user.ID).First(&account), "finding account") - testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&tokenCount), "counting token") + testutils.MustExec(t, db.Where("user_id = ?", user.ID).First(&account), "finding account") + testutils.MustExec(t, db.Model(&database.Token{}).Count(&tokenCount), "counting token") assert.Equal(t, account.EmailVerified, true, "email_verified should not have been updated") assert.Equal(t, tokenCount, int64(0), "token count mismatch") diff --git a/pkg/server/database/consts.go b/pkg/server/database/consts.go index 38c4d536..b4a1db03 100644 --- a/pkg/server/database/consts.go +++ b/pkg/server/database/consts.go @@ -23,8 +23,6 @@ const ( TokenTypeResetPassword = "reset_password" // TokenTypeEmailVerification is a type of a token for verifying email TokenTypeEmailVerification = "email_verification" - // TokenTypeEmailPreference is a type of a token for updating email preference - TokenTypeEmailPreference = "email_preference" ) const ( diff --git a/pkg/server/database/database.go b/pkg/server/database/database.go index 3c5b6d9b..eaab7c50 100644 --- a/pkg/server/database/database.go +++ b/pkg/server/database/database.go @@ -19,9 +19,11 @@ package database import ( - "github.com/dnote/dnote/pkg/server/config" + "os" + "path/filepath" + "github.com/pkg/errors" - "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" "gorm.io/gorm" ) @@ -32,18 +34,12 @@ var ( // InitSchema migrates database schema to reflect the latest model definition func InitSchema(db *gorm.DB) { - if err := db.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`).Error; err != nil { - panic(err) - } - if err := db.AutoMigrate( &User{}, &Account{}, &Book{}, &Note{}, - &Notification{}, &Token{}, - &EmailPreference{}, &Session{}, ); err != nil { panic(err) @@ -51,8 +47,14 @@ func InitSchema(db *gorm.DB) { } // Open initializes the database connection -func Open(c config.Config) *gorm.DB { - db, err := gorm.Open(postgres.Open(c.DB.GetConnectionStr()), &gorm.Config{}) +func Open(dbPath string) *gorm.DB { + // Create directory if it doesn't exist + dir := filepath.Dir(dbPath) + if err := os.MkdirAll(dir, 0755); err != nil { + panic(errors.Wrapf(err, "creating database directory at %s", dir)) + } + + db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{}) if err != nil { panic(errors.Wrap(err, "opening database conection")) } diff --git a/pkg/server/database/migrate.go b/pkg/server/database/migrate.go index 8b258b12..4250fd80 100644 --- a/pkg/server/database/migrate.go +++ b/pkg/server/database/migrate.go @@ -19,34 +19,167 @@ package database import ( - "log" - "net/http" + "fmt" + "io/fs" + "sort" + "strings" "github.com/dnote/dnote/pkg/server/database/migrations" - "gorm.io/gorm" + "github.com/dnote/dnote/pkg/server/log" "github.com/pkg/errors" - "github.com/rubenv/sql-migrate" + "gorm.io/gorm" ) -// Migrate runs the migrations -func Migrate(db *gorm.DB) error { - migrations := &migrate.HttpFileSystemMigrationSource{ - FileSystem: http.FileSystem(http.FS(migrations.Files)), +type migrationFile struct { + filename string + version int +} + +// validateMigrationFilename checks if filename follows format: NNN-description.sql +func validateMigrationFilename(name string) error { + // Check .sql extension + if !strings.HasSuffix(name, ".sql") { + return errors.Errorf("invalid migration filename: must end with .sql") } - migrate.SetTable(MigrationTableName) - - sqlDB, err := db.DB() - if err != nil { - return errors.Wrap(err, "getting underlying sql.DB") + name = strings.TrimSuffix(name, ".sql") + parts := strings.SplitN(name, "-", 2) + if len(parts) != 2 { + return errors.Errorf("invalid migration filename: must be NNN-description.sql") } - n, err := migrate.Exec(sqlDB, "postgres", migrations, migrate.Up) - if err != nil { - return errors.Wrap(err, "running migrations") + version, description := parts[0], parts[1] + + // Validate version is 3 digits + if len(version) != 3 { + return errors.Errorf("invalid migration filename: version must be 3 digits, got %s", version) + } + for _, c := range version { + if c < '0' || c > '9' { + return errors.Errorf("invalid migration filename: version must be numeric, got %s", version) + } } - log.Printf("Performed %d migrations", n) + // Validate description is not empty + if description == "" { + return errors.Errorf("invalid migration filename: description is required") + } + + return nil +} + +// Migrate runs the migrations using the embedded migration files +func Migrate(db *gorm.DB) error { + return migrate(db, migrations.Files) +} + +// getMigrationFiles reads, validates, and sorts migration files +func getMigrationFiles(fsys fs.FS) ([]migrationFile, error) { + entries, err := fs.ReadDir(fsys, ".") + if err != nil { + return nil, errors.Wrap(err, "reading migration directory") + } + + var migrations []migrationFile + seen := make(map[int]string) + for _, e := range entries { + name := e.Name() + + if err := validateMigrationFilename(name); err != nil { + return nil, err + } + + // Parse version + var v int + fmt.Sscanf(name, "%d", &v) + + // Check for duplicate version numbers + if existing, found := seen[v]; found { + return nil, errors.Errorf("duplicate migration version %d: %s and %s", v, existing, name) + } + seen[v] = name + + migrations = append(migrations, migrationFile{ + filename: name, + version: v, + }) + } + + // Sort by version + sort.Slice(migrations, func(i, j int) bool { + return migrations[i].version < migrations[j].version + }) + + return migrations, nil +} + +// migrate runs migrations from the provided filesystem +func migrate(db *gorm.DB, fsys fs.FS) error { + if err := db.Exec(` + CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + applied_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + `).Error; err != nil { + return errors.Wrap(err, "initializing migration table") + } + + // Get current version + var version int + if err := db.Raw("SELECT COALESCE(MAX(version), 0) FROM schema_migrations").Scan(&version).Error; err != nil { + return errors.Wrap(err, "reading current version") + } + + // Read and validate migration files + migrations, err := getMigrationFiles(fsys) + if err != nil { + return err + } + + var filenames []string + for _, m := range migrations { + filenames = append(filenames, m.filename) + } + + log.WithFields(log.Fields{ + "version": version, + }).Info("Database schema version.") + + log.WithFields(log.Fields{ + "files": filenames, + }).Debug("Database migration files.") + + // Apply pending migrations + for _, m := range migrations { + if m.version <= version { + continue + } + + log.WithFields(log.Fields{ + "file": m.filename, + }).Info("Applying migration.") + + sql, err := fs.ReadFile(fsys, m.filename) + if err != nil { + return errors.Wrapf(err, "reading migration file %s", m.filename) + } + + if len(strings.TrimSpace(string(sql))) == 0 { + return errors.Errorf("migration file %s is empty", m.filename) + } + + if err := db.Exec(string(sql)).Error; err != nil { + return fmt.Errorf("migration %s failed: %w", m.filename, err) + } + + if err := db.Exec("INSERT INTO schema_migrations (version) VALUES (?)", m.version).Error; err != nil { + return errors.Wrapf(err, "recording migration %s", m.filename) + } + + log.WithFields(log.Fields{ + "file": m.filename, + }).Info("Migrate success.") + } return nil } diff --git a/pkg/server/database/migrate/main.go b/pkg/server/database/migrate/main.go deleted file mode 100644 index 932f4ec3..00000000 --- a/pkg/server/database/migrate/main.go +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors - * - * This file is part of Dnote. - * - * Dnote is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * Dnote is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with Dnote. If not, see . - */ - -package main - -import ( - "flag" - "fmt" - "os" - - "github.com/dnote/dnote/pkg/server/config" - "github.com/dnote/dnote/pkg/server/database" - "github.com/joho/godotenv" - "github.com/pkg/errors" - "github.com/rubenv/sql-migrate" -) - -var ( - migrationDir = flag.String("migrationDir", "../migrations", "the path to the directory with migraiton files") -) - -func init() { - fmt.Println("Migrating Dnote database...") - - // Load env - if os.Getenv("GO_ENV") != "PRODUCTION" { - if err := godotenv.Load("../../.env.dev"); err != nil { - panic(err) - } - } - -} - -func main() { - flag.Parse() - - c := config.Load() - db := database.Open(c) - - migrations := &migrate.FileMigrationSource{ - Dir: *migrationDir, - } - - migrate.SetTable("migrations") - - sqlDB, err := db.DB() - if err != nil { - panic(errors.Wrap(err, "getting underlying sql.DB")) - } - - n, err := migrate.Exec(sqlDB, "postgres", migrations, migrate.Up) - if err != nil { - panic(errors.Wrap(err, "executing migrations")) - } - - fmt.Printf("Applied %d migrations\n", n) -} diff --git a/pkg/server/database/migrate_test.go b/pkg/server/database/migrate_test.go new file mode 100644 index 00000000..31853730 --- /dev/null +++ b/pkg/server/database/migrate_test.go @@ -0,0 +1,313 @@ +/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors + * + * This file is part of Dnote. + * + * Dnote is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Dnote is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Dnote. If not, see . + */ + +package database + +import ( + "io/fs" + "testing" + "testing/fstest" + + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +// unsortedFS wraps fstest.MapFS to return entries in reverse order +type unsortedFS struct { + fstest.MapFS +} + +func (u unsortedFS) ReadDir(name string) ([]fs.DirEntry, error) { + entries, err := u.MapFS.ReadDir(name) + if err != nil { + return nil, err + } + // Reverse the entries to ensure they're not in sorted order + for i, j := 0, len(entries)-1; i < j; i, j = i+1, j-1 { + entries[i], entries[j] = entries[j], entries[i] + } + return entries, nil +} + +// errorFS returns an error on ReadDir +type errorFS struct{} + +func (e errorFS) Open(name string) (fs.File, error) { + return nil, fs.ErrNotExist +} + +func (e errorFS) ReadDir(name string) ([]fs.DirEntry, error) { + return nil, fs.ErrPermission +} + +func TestMigrate_createsSchemaTable(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + + migrationsFs := fstest.MapFS{} + migrate(db, migrationsFs) + + // Verify schema_migrations table exists + var count int64 + if err := db.Raw("SELECT COUNT(*) FROM schema_migrations").Scan(&count).Error; err != nil { + t.Fatalf("schema_migrations table not found: %v", err) + } +} + +func TestMigrate_idempotency(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + + // Set up table before migration + if err := db.Exec("CREATE TABLE counter (value INTEGER)").Error; err != nil { + t.Fatalf("failed to create table: %v", err) + } + + // Create migration that inserts a row + migrationsFs := fstest.MapFS{ + "001-insert-data.sql": &fstest.MapFile{ + Data: []byte("INSERT INTO counter (value) VALUES (100);"), + }, + } + + // Run migration first time + if err := migrate(db, migrationsFs); err != nil { + t.Fatalf("first migration failed: %v", err) + } + var count int64 + if err := db.Raw("SELECT COUNT(*) FROM counter").Scan(&count).Error; err != nil { + t.Fatalf("failed to count rows: %v", err) + } + if count != 1 { + t.Errorf("expected 1 row, got %d", count) + } + + // Run migration second time - it should not run the SQL again + if err := migrate(db, migrationsFs); err != nil { + t.Fatalf("second migration failed: %v", err) + } + if err := db.Raw("SELECT COUNT(*) FROM counter").Scan(&count).Error; err != nil { + t.Fatalf("failed to count rows: %v", err) + } + if count != 1 { + t.Errorf("migration ran twice: expected 1 row, got %d", count) + } +} + +func TestMigrate_ordering(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + + // Create table before migrations + if err := db.Exec("CREATE TABLE log (value INTEGER)").Error; err != nil { + t.Fatalf("failed to create table: %v", err) + } + + // Create migrations with unsorted filesystem + migrationsFs := unsortedFS{ + MapFS: fstest.MapFS{ + "010-tenth.sql": &fstest.MapFile{ + Data: []byte("INSERT INTO log (value) VALUES (3);"), + }, + "001-first.sql": &fstest.MapFile{ + Data: []byte("INSERT INTO log (value) VALUES (1);"), + }, + "002-second.sql": &fstest.MapFile{ + Data: []byte("INSERT INTO log (value) VALUES (2);"), + }, + }, + } + + // Run migrations + if err := migrate(db, migrationsFs); err != nil { + t.Fatalf("migration failed: %v", err) + } + + // Verify migrations ran in correct order (1, 2, 3) + var values []int + if err := db.Raw("SELECT value FROM log ORDER BY rowid").Scan(&values).Error; err != nil { + t.Fatalf("failed to query log: %v", err) + } + + expected := []int{1, 2, 3} + if len(values) != len(expected) { + t.Fatalf("expected %d rows, got %d", len(expected), len(values)) + } + + for i, v := range values { + if v != expected[i] { + t.Errorf("row %d: expected value %d, got %d", i, expected[i], v) + } + } +} + +func TestMigrate_duplicateVersion(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + + // Create migrations with duplicate version numbers + migrationsFs := fstest.MapFS{ + "001-first.sql": &fstest.MapFile{ + Data: []byte("SELECT 1;"), + }, + "001-second.sql": &fstest.MapFile{ + Data: []byte("SELECT 2;"), + }, + } + + // Should return error for duplicate version + err = migrate(db, migrationsFs) + if err == nil { + t.Fatal("expected error for duplicate version numbers, got nil") + } +} + +func TestMigrate_initTableError(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + + // Close the database connection to cause exec to fail + sqlDB, _ := db.DB() + sqlDB.Close() + + migrationsFs := fstest.MapFS{ + "001-init.sql": &fstest.MapFile{ + Data: []byte("SELECT 1;"), + }, + } + + // Should return error for table initialization failure + err = migrate(db, migrationsFs) + if err == nil { + t.Fatal("expected error for table initialization failure, got nil") + } +} + +func TestMigrate_readDirError(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + + // Use filesystem that fails on ReadDir + err = migrate(db, errorFS{}) + if err == nil { + t.Fatal("expected error for ReadDir failure, got nil") + } +} + +func TestMigrate_sqlError(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + + // Create migration with invalid SQL + migrationsFs := fstest.MapFS{ + "001-bad-sql.sql": &fstest.MapFile{ + Data: []byte("INVALID SQL SYNTAX HERE;"), + }, + } + + // Should return error for SQL execution failure + err = migrate(db, migrationsFs) + if err == nil { + t.Fatal("expected error for invalid SQL, got nil") + } +} + +func TestMigrate_emptyFile(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + + tests := []struct { + name string + data string + wantErr bool + }{ + {"completely empty", "", true}, + {"only whitespace", " \n\t ", true}, + {"only comments", "-- just a comment", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + migrationsFs := fstest.MapFS{ + "001-empty.sql": &fstest.MapFile{ + Data: []byte(tt.data), + }, + } + + err = migrate(db, migrationsFs) + if (err != nil) != tt.wantErr { + t.Errorf("error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestMigrate_invalidFilename(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + + tests := []struct { + name string + filename string + wantErr bool + }{ + {"valid format", "001-init.sql", false}, + {"no leading zeros", "1-init.sql", true}, + {"two digits", "01-init.sql", true}, + {"no dash", "001init.sql", true}, + {"no description", "001-.sql", true}, + {"no extension", "001-init.", true}, + {"wrong extension", "001-init.txt", true}, + {"non-numeric version number", "0a1-init.sql", true}, + {"underscore separator", "001_init.sql", true}, + {"multiple dashes in description", "001-add-feature-v2.sql", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + migrationsFs := fstest.MapFS{ + tt.filename: &fstest.MapFile{ + Data: []byte("SELECT 1;"), + }, + } + + err := migrate(db, migrationsFs) + if (err != nil) != tt.wantErr { + t.Errorf("error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/pkg/server/database/migrations/.gitkeep b/pkg/server/database/migrations/.gitkeep deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/server/database/migrations/100-create-fts-table.sql b/pkg/server/database/migrations/100-create-fts-table.sql new file mode 100644 index 00000000..21b43704 --- /dev/null +++ b/pkg/server/database/migrations/100-create-fts-table.sql @@ -0,0 +1,18 @@ +-- Create FTS5 virtual table for full-text search on notes +CREATE VIRTUAL TABLE IF NOT EXISTS notes_fts USING fts5( + content=notes, + body, + tokenize="porter unicode61 categories 'L* N* Co Ps Pe'" +); + +-- Create triggers to keep notes_fts in sync with notes +CREATE TRIGGER IF NOT EXISTS notes_insert AFTER INSERT ON notes BEGIN + INSERT INTO notes_fts(rowid, body) VALUES (new.rowid, new.body); +END; +CREATE TRIGGER IF NOT EXISTS notes_delete AFTER DELETE ON notes BEGIN + INSERT INTO notes_fts(notes_fts, rowid, body) VALUES ('delete', old.rowid, old.body); +END; +CREATE TRIGGER IF NOT EXISTS notes_update AFTER UPDATE ON notes BEGIN + INSERT INTO notes_fts(notes_fts, rowid, body) VALUES ('delete', old.rowid, old.body); + INSERT INTO notes_fts(rowid, body) VALUES (new.rowid, new.body); +END; \ No newline at end of file diff --git a/pkg/server/database/migrations/20190819115834-full-text-search.sql b/pkg/server/database/migrations/20190819115834-full-text-search.sql deleted file mode 100644 index b3d884e9..00000000 --- a/pkg/server/database/migrations/20190819115834-full-text-search.sql +++ /dev/null @@ -1,41 +0,0 @@ - --- +migrate Up - --- Configure full text search -CREATE TEXT SEARCH DICTIONARY english_nostop ( - Template = snowball, - Language = english -); - -CREATE TEXT SEARCH CONFIGURATION public.english_nostop ( COPY = pg_catalog.english ); - -ALTER TEXT SEARCH CONFIGURATION public.english_nostop -ALTER MAPPING FOR asciiword, asciihword, hword_asciipart, hword, hword_part, word WITH english_nostop; - - --- Create a trigger --- +migrate StatementBegin -CREATE OR REPLACE FUNCTION note_tsv_trigger() RETURNS trigger AS $$ -begin - new.tsv := setweight(to_tsvector('english_nostop', new.body), 'A'); - return new; -end -$$ LANGUAGE plpgsql; - -DROP TRIGGER IF EXISTS tsvectorupdate ON notes; -CREATE TRIGGER tsvectorupdate -BEFORE INSERT OR UPDATE ON notes -FOR EACH ROW EXECUTE PROCEDURE note_tsv_trigger(); --- +migrate StatementEnd - --- index tsv -CREATE INDEX IF NOT EXISTS idx_notes_tsv -ON notes -USING gin(tsv); - --- initialize tsv -UPDATE notes -SET tsv = setweight(to_tsvector('english_nostop', notes.body), 'A') -WHERE notes.encrypted = false; - --- +migrate Down diff --git a/pkg/server/database/migrations/20191028103522-create-weekly-repetition.sql b/pkg/server/database/migrations/20191028103522-create-weekly-repetition.sql deleted file mode 100644 index 22ac6f3a..00000000 --- a/pkg/server/database/migrations/20191028103522-create-weekly-repetition.sql +++ /dev/null @@ -1,8 +0,0 @@ --- this migration is noop because repetition_rules have been removed - --- create-weekly-repetition.sql creates the default repetition rules for the users --- that used to have the weekly email digest on Friday 20:00 UTC - --- +migrate Up - --- +migrate Down diff --git a/pkg/server/database/migrations/20191225185502-populate-digest-version.sql b/pkg/server/database/migrations/20191225185502-populate-digest-version.sql deleted file mode 100644 index 73098dbf..00000000 --- a/pkg/server/database/migrations/20191225185502-populate-digest-version.sql +++ /dev/null @@ -1,9 +0,0 @@ --- this migration is noop because digests have been removed - --- populate-digest-version.sql populates the `version` column for the digests --- by assigining an incremental integer scoped to a repetition rule that each --- digest belongs, ordered by created_at timestamp of the digests. - --- +migrate Up - --- +migrate Down diff --git a/pkg/server/database/migrations/20191226093447-add-digest-id-primary-key.sql b/pkg/server/database/migrations/20191226093447-add-digest-id-primary-key.sql deleted file mode 100644 index c7d17d2e..00000000 --- a/pkg/server/database/migrations/20191226093447-add-digest-id-primary-key.sql +++ /dev/null @@ -1,5 +0,0 @@ --- this migration is noop because digests have been removed - --- +migrate Up - --- +migrate Down diff --git a/pkg/server/database/migrations/20191226105659-use-id-in-digest-notes-joining-table.sql b/pkg/server/database/migrations/20191226105659-use-id-in-digest-notes-joining-table.sql deleted file mode 100644 index faec52ff..00000000 --- a/pkg/server/database/migrations/20191226105659-use-id-in-digest-notes-joining-table.sql +++ /dev/null @@ -1,8 +0,0 @@ --- this migration is noop because digests have been removed - --- -use-id-in-digest-notes-joining-table.sql replaces uuids with ids --- as foreign keys in the digest_notes joining table. - --- +migrate Up - --- +migrate Down diff --git a/pkg/server/database/migrations/20191226152111-delete-outdated-digests.sql b/pkg/server/database/migrations/20191226152111-delete-outdated-digests.sql deleted file mode 100644 index 84c8ccf3..00000000 --- a/pkg/server/database/migrations/20191226152111-delete-outdated-digests.sql +++ /dev/null @@ -1,8 +0,0 @@ --- this migration is noop because digests have been removed - --- delete-outdated-digests.sql deletes digests that do not belong to any repetition rules, --- along with digest_notes associations. - --- +migrate Up - --- +migrate Down diff --git a/pkg/server/database/migrations/20200522170529-remove-billing-columns.sql b/pkg/server/database/migrations/20200522170529-remove-billing-columns.sql deleted file mode 100644 index a814f26b..00000000 --- a/pkg/server/database/migrations/20200522170529-remove-billing-columns.sql +++ /dev/null @@ -1,9 +0,0 @@ --- remove-billing-columns.sql drops billing related columns that are now obsolete. - --- +migrate Up - -ALTER TABLE users DROP COLUMN IF EXISTS stripe_customer_id; -ALTER TABLE users DROP COLUMN IF EXISTS billing_country; - --- +migrate Down - diff --git a/pkg/server/database/models.go b/pkg/server/database/models.go index 2126c106..ac3e4e56 100644 --- a/pkg/server/database/models.go +++ b/pkg/server/database/models.go @@ -25,14 +25,14 @@ import ( // Model is the base model definition type Model struct { ID int `gorm:"primaryKey" json:"-"` - CreatedAt time.Time `json:"created_at" gorm:"default:now()"` - UpdatedAt time.Time `json:"updated_at"` + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` } // Book is a model for a book type Book struct { Model - UUID string `json:"uuid" gorm:"uniqueIndex;type:uuid;default:uuid_generate_v4()"` + UUID string `json:"uuid" gorm:"uniqueIndex;type:text"` UserID int `json:"user_id" gorm:"index"` Label string `json:"label" gorm:"index"` Notes []Note `json:"notes" gorm:"foreignKey:BookUUID;references:UUID"` @@ -46,15 +46,14 @@ type Book struct { // Note is a model for a note type Note struct { Model - UUID string `json:"uuid" gorm:"index;type:uuid;default:uuid_generate_v4()"` + UUID string `json:"uuid" gorm:"index;type:text"` Book Book `json:"book" gorm:"foreignKey:BookUUID;references:UUID"` User User `json:"user"` UserID int `json:"user_id" gorm:"index"` - BookUUID string `json:"book_uuid" gorm:"index;type:uuid"` + BookUUID string `json:"book_uuid" gorm:"index;type:text"` Body string `json:"content"` AddedOn int64 `json:"added_on"` EditedOn int64 `json:"edited_on"` - TSV string `json:"-" gorm:"type:tsvector"` Public bool `json:"public" gorm:"default:false"` USN int `json:"-" gorm:"index"` Deleted bool `json:"-" gorm:"default:false"` @@ -65,11 +64,10 @@ type Note struct { // User is a model for a user type User struct { Model - UUID string `json:"uuid" gorm:"type:uuid;index;default:uuid_generate_v4()"` + UUID string `json:"uuid" gorm:"type:text;index"` Account Account `gorm:"foreignKey:UserID"` LastLoginAt *time.Time `json:"-"` MaxUSN int `json:"-" gorm:"default:0"` - Cloud bool `json:"-" gorm:"default:false"` } // Account is a model for an account @@ -90,21 +88,6 @@ type Token struct { UsedAt *time.Time } -// Notification is the learning notification sent to the user -type Notification struct { - Model - Type string - UserID int `gorm:"index"` -} - -// EmailPreference is a preference per user for receiving email communication -type EmailPreference struct { - Model - UserID int `gorm:"index" json:"-"` - InactiveReminder bool `json:"inactive_reminder" gorm:"default:false"` - ProductUpdate bool `json:"product_update" gorm:"default:true"` -} - // Session represents a user session type Session struct { Model diff --git a/pkg/server/database/scripts/create-migration.sh b/pkg/server/database/scripts/create-migration.sh deleted file mode 100755 index ef168165..00000000 --- a/pkg/server/database/scripts/create-migration.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/usr/bin/env bash -# create-migration.sh creates a new SQL migration file for the -# server side Postgres database using the sql-migrate tool. -set -eux - -is_command () { - command -v "$1" >/dev/null 2>&1; -} - -if ! is_command sql-migrate; then - echo "sql-migrate is not found. Please run install-sql-migrate.sh" - exit 1 -fi - -if [ "$#" == 0 ]; then - echo "filename not provided" - exit 1 -fi - -filename=$1 -sql-migrate new -config=./sql-migrate.yml "$filename" diff --git a/pkg/server/database/scripts/install-sql-migrate.sh b/pkg/server/database/scripts/install-sql-migrate.sh deleted file mode 100755 index 334fb817..00000000 --- a/pkg/server/database/scripts/install-sql-migrate.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env bash - -go get -v github.com/rubenv/sql-migrate/... diff --git a/pkg/server/database/sql-migrate.yml b/pkg/server/database/sql-migrate.yml deleted file mode 100644 index f9c90d83..00000000 --- a/pkg/server/database/sql-migrate.yml +++ /dev/null @@ -1,8 +0,0 @@ -# A configuration for sql-migrate tool for generating migrations -# using `sql-migrate new`. This file is not actually used for running -# migrations because we run them programmatically. - -development: - dialect: postgres - datasource: dbname=dnote sslmode=disable - dir: ./migrations diff --git a/pkg/server/job/job.go b/pkg/server/job/job.go deleted file mode 100644 index 12efd823..00000000 --- a/pkg/server/job/job.go +++ /dev/null @@ -1,127 +0,0 @@ -/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors - * - * This file is part of Dnote. - * - * Dnote is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * Dnote is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with Dnote. If not, see . - */ - -package job - -import ( - slog "log" - - "github.com/dnote/dnote/pkg/clock" - "github.com/dnote/dnote/pkg/server/config" - "github.com/dnote/dnote/pkg/server/mailer" - "gorm.io/gorm" - "github.com/pkg/errors" - "github.com/robfig/cron" -) - -var ( - // ErrEmptyDB is an error for missing database connection in the app configuration - ErrEmptyDB = errors.New("No database connection was provided") - // ErrEmptyClock is an error for missing clock in the app configuration - ErrEmptyClock = errors.New("No clock was provided") - // ErrEmptyWebURL is an error for missing WebURL content in the app configuration - ErrEmptyWebURL = errors.New("No WebURL was provided") - // ErrEmptyEmailTemplates is an error for missing EmailTemplates content in the app configuration - ErrEmptyEmailTemplates = errors.New("No EmailTemplate store was provided") - // ErrEmptyEmailBackend is an error for missing EmailBackend content in the app configuration - ErrEmptyEmailBackend = errors.New("No EmailBackend was provided") -) - -// Runner is a configuration for job -type Runner struct { - DB *gorm.DB - Clock clock.Clock - EmailTmpl mailer.Templates - EmailBackend mailer.Backend - Config config.Config -} - -// NewRunner returns a new runner -func NewRunner(db *gorm.DB, c clock.Clock, t mailer.Templates, b mailer.Backend, config config.Config) (Runner, error) { - ret := Runner{ - DB: db, - EmailTmpl: t, - EmailBackend: b, - Clock: c, - Config: config, - } - - if err := ret.validate(); err != nil { - return Runner{}, errors.Wrap(err, "validating runner configuration") - } - - return ret, nil -} - -func (r *Runner) validate() error { - if r.DB == nil { - return ErrEmptyDB - } - if r.Clock == nil { - return ErrEmptyClock - } - if r.EmailTmpl == nil { - return ErrEmptyEmailTemplates - } - if r.EmailBackend == nil { - return ErrEmptyEmailBackend - } - if r.Config.WebURL == "" { - return ErrEmptyWebURL - } - - return nil -} - -func scheduleJob(c *cron.Cron, spec string, cmd func()) { - s, err := cron.ParseStandard(spec) - if err != nil { - panic(errors.Wrap(err, "parsing schedule")) - } - - c.Schedule(s, cron.FuncJob(cmd)) -} - -func (r *Runner) schedule(ch chan error) { - // Schedule jobs - cr := cron.New() - cr.Start() - - ch <- nil - - // Block forever - select {} -} - -// Do starts the background tasks in a separate goroutine that runs forever -func (r *Runner) Do() error { - // validate - if err := r.validate(); err != nil { - return errors.Wrap(err, "validating job configurations") - } - - ch := make(chan error) - go r.schedule(ch) - if err := <-ch; err != nil { - return errors.Wrap(err, "scheduling jobs") - } - - slog.Println("Started background tasks") - - return nil -} diff --git a/pkg/server/job/job_test.go b/pkg/server/job/job_test.go deleted file mode 100644 index 885d429f..00000000 --- a/pkg/server/job/job_test.go +++ /dev/null @@ -1,104 +0,0 @@ -/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors - * - * This file is part of Dnote. - * - * Dnote is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * Dnote is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with Dnote. If not, see . - */ - -package job - -import ( - "fmt" - "testing" - - "github.com/dnote/dnote/pkg/assert" - "github.com/dnote/dnote/pkg/clock" - "github.com/dnote/dnote/pkg/server/config" - "github.com/dnote/dnote/pkg/server/mailer" - "github.com/dnote/dnote/pkg/server/testutils" - "gorm.io/gorm" - "github.com/pkg/errors" -) - -func TestNewRunner(t *testing.T) { - testCases := []struct { - db *gorm.DB - clock clock.Clock - emailTmpl mailer.Templates - emailBackend mailer.Backend - webURL string - expectedErr error - }{ - { - db: &gorm.DB{}, - clock: clock.NewMock(), - emailTmpl: mailer.Templates{}, - emailBackend: &testutils.MockEmailbackendImplementation{}, - webURL: "http://mock.url", - expectedErr: nil, - }, - { - db: nil, - clock: clock.NewMock(), - emailTmpl: mailer.Templates{}, - emailBackend: &testutils.MockEmailbackendImplementation{}, - webURL: "http://mock.url", - expectedErr: ErrEmptyDB, - }, - { - db: &gorm.DB{}, - clock: nil, - emailTmpl: mailer.Templates{}, - emailBackend: &testutils.MockEmailbackendImplementation{}, - webURL: "http://mock.url", - expectedErr: ErrEmptyClock, - }, - { - db: &gorm.DB{}, - clock: clock.NewMock(), - emailTmpl: nil, - emailBackend: &testutils.MockEmailbackendImplementation{}, - webURL: "http://mock.url", - expectedErr: ErrEmptyEmailTemplates, - }, - { - db: &gorm.DB{}, - clock: clock.NewMock(), - emailTmpl: mailer.Templates{}, - emailBackend: nil, - webURL: "http://mock.url", - expectedErr: ErrEmptyEmailBackend, - }, - { - db: &gorm.DB{}, - clock: clock.NewMock(), - emailTmpl: mailer.Templates{}, - emailBackend: &testutils.MockEmailbackendImplementation{}, - webURL: "", - expectedErr: ErrEmptyWebURL, - }, - } - - for idx, tc := range testCases { - t.Run(fmt.Sprintf("test case %d", idx), func(t *testing.T) { - - c := config.Load() - c.WebURL = tc.webURL - - _, err := NewRunner(tc.db, tc.clock, tc.emailTmpl, tc.emailBackend, c) - - assert.Equal(t, errors.Cause(err), tc.expectedErr, "error mismatch") - }) - } -} diff --git a/pkg/server/log/log.go b/pkg/server/log/log.go index 78f737bd..89eeb719 100644 --- a/pkg/server/log/log.go +++ b/pkg/server/log/log.go @@ -32,9 +32,19 @@ const ( fieldKeyTimestamp = "ts" fieldKeyUnixTimestamp = "ts_unix" - levelInfo = "info" - levelWarn = "warn" - levelError = "error" + // LevelDebug represents debug log level + LevelDebug = "debug" + // LevelInfo represents info log level + LevelInfo = "info" + // LevelWarn represents warn log level + LevelWarn = "warn" + // LevelError represents error log level + LevelError = "error" +) + +var ( + // currentLevel is the currently configured log level + currentLevel = LevelInfo ) // Fields represents a set of information to be included in the log @@ -58,19 +68,50 @@ func WithFields(fields Fields) Entry { return newEntry(fields) } +// SetLevel sets the global log level +func SetLevel(level string) { + currentLevel = level +} + +// levelPriority returns a numeric priority for comparison +func levelPriority(level string) int { + switch level { + case LevelDebug: + return 0 + case LevelInfo: + return 1 + case LevelWarn: + return 2 + case LevelError: + return 3 + default: + return 1 + } +} + +// shouldLog returns true if the given level should be logged based on currentLevel +func shouldLog(level string) bool { + return levelPriority(level) >= levelPriority(currentLevel) +} + +// Debug logs the given entry at a debug level +func (e Entry) Debug(msg string) { + e.write(LevelDebug, msg) +} + // Info logs the given entry at an info level func (e Entry) Info(msg string) { - e.write(levelInfo, msg) + e.write(LevelInfo, msg) } // Warn logs the given entry at a warning level func (e Entry) Warn(msg string) { - e.write(levelWarn, msg) + e.write(LevelWarn, msg) } // Error logs the given entry at an error level func (e Entry) Error(msg string) { - e.write(levelError, msg) + e.write(LevelError, msg) } // ErrorWrap logs the given entry with the error message annotated by the given message @@ -106,6 +147,10 @@ func (e Entry) formatJSON(level, msg string) []byte { } func (e Entry) write(level, msg string) { + if !shouldLog(level) { + return + } + serialized := e.formatJSON(level, msg) _, err := fmt.Fprintln(os.Stderr, string(serialized)) @@ -114,6 +159,11 @@ func (e Entry) write(level, msg string) { } } +// Debug logs a debug message without additional fields +func Debug(msg string) { + newEntry(Fields{}).Debug(msg) +} + // Info logs an info message without additional fields func Info(msg string) { newEntry(Fields{}).Info(msg) diff --git a/pkg/server/mailer/backend.go b/pkg/server/mailer/backend.go index eb6c3893..0abe51d8 100644 --- a/pkg/server/mailer/backend.go +++ b/pkg/server/mailer/backend.go @@ -19,11 +19,10 @@ package mailer import ( - "fmt" - "log" "os" "strconv" + "github.com/dnote/dnote/pkg/server/log" "github.com/pkg/errors" "gopkg.in/gomail.v2" ) @@ -36,9 +35,21 @@ type Backend interface { Queue(subject, from string, to []string, contentType, body string) error } -// SimpleBackendImplementation is an implementation of the Backend +// EmailDialer is an interface for sending email messages +type EmailDialer interface { + DialAndSend(m ...*gomail.Message) error +} + +// gomailDialer wraps gomail.Dialer to implement EmailDialer interface +type gomailDialer struct { + *gomail.Dialer +} + +// DefaultBackend is an implementation of the Backend // that sends an email without queueing. -type SimpleBackendImplementation struct { +type DefaultBackend struct { + Dialer EmailDialer + Enabled bool } type dialerParams struct { @@ -73,13 +84,31 @@ func getSMTPParams() (*dialerParams, error) { return p, nil } +// NewDefaultBackend creates a default backend +func NewDefaultBackend(enabled bool) (*DefaultBackend, error) { + p, err := getSMTPParams() + if err != nil { + return nil, err + } + + d := gomail.NewDialer(p.Host, p.Port, p.Username, p.Password) + + return &DefaultBackend{ + Dialer: &gomailDialer{Dialer: d}, + Enabled: enabled, + }, nil +} + // Queue is an implementation of Backend.Queue. -func (b *SimpleBackendImplementation) Queue(subject, from string, to []string, contentType, body string) error { - // If not production, never actually send an email - if os.Getenv("GO_ENV") != "PRODUCTION" { - log.Println("Not sending email because Dnote is not running in a production environment.") - log.Printf("Subject: %s, to: %s, from: %s", subject, to, from) - fmt.Println(body) +func (b *DefaultBackend) Queue(subject, from string, to []string, contentType, body string) error { + // If not enabled, just log the email + if !b.Enabled { + log.WithFields(log.Fields{ + "subject": subject, + "to": to, + "from": from, + "body": body, + }).Info("Not sending email because email backend is not configured.") return nil } @@ -89,13 +118,7 @@ func (b *SimpleBackendImplementation) Queue(subject, from string, to []string, c m.SetHeader("Subject", subject) m.SetBody(contentType, body) - p, err := getSMTPParams() - if err != nil { - return errors.Wrap(err, "getting dialer params") - } - - d := gomail.NewPlainDialer(p.Host, p.Port, p.Username, p.Password) - if err := d.DialAndSend(m); err != nil { + if err := b.Dialer.DialAndSend(m); err != nil { return errors.Wrap(err, "dialing and sending email") } diff --git a/pkg/server/mailer/backend_test.go b/pkg/server/mailer/backend_test.go new file mode 100644 index 00000000..5ef0a355 --- /dev/null +++ b/pkg/server/mailer/backend_test.go @@ -0,0 +1,107 @@ +/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors + * + * This file is part of Dnote. + * + * Dnote is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Dnote is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Dnote. If not, see . + */ + +package mailer + +import ( + "testing" + + "gopkg.in/gomail.v2" +) + +type mockDialer struct { + sentMessages []*gomail.Message + err error +} + +func (m *mockDialer) DialAndSend(msgs ...*gomail.Message) error { + m.sentMessages = append(m.sentMessages, msgs...) + return m.err +} + +func TestDefaultBackendQueue(t *testing.T) { + t.Run("enabled sends email", func(t *testing.T) { + mock := &mockDialer{} + backend := &DefaultBackend{ + Dialer: mock, + Enabled: true, + } + + err := backend.Queue("Test Subject", "alice@example.com", []string{"bob@example.com"}, "text/plain", "Test body") + if err != nil { + t.Fatalf("Queue failed: %v", err) + } + + if len(mock.sentMessages) != 1 { + t.Errorf("expected 1 message sent, got %d", len(mock.sentMessages)) + } + }) + + t.Run("disabled does not send email", func(t *testing.T) { + mock := &mockDialer{} + backend := &DefaultBackend{ + Dialer: mock, + Enabled: false, + } + + err := backend.Queue("Test Subject", "alice@example.com", []string{"bob@example.com"}, "text/plain", "Test body") + if err != nil { + t.Fatalf("Queue failed: %v", err) + } + + if len(mock.sentMessages) != 0 { + t.Errorf("expected 0 messages sent when disabled, got %d", len(mock.sentMessages)) + } + }) +} + +func TestNewDefaultBackend(t *testing.T) { + t.Run("with all env vars set", func(t *testing.T) { + t.Setenv("SmtpHost", "smtp.example.com") + t.Setenv("SmtpPort", "587") + t.Setenv("SmtpUsername", "user@example.com") + t.Setenv("SmtpPassword", "secret") + + backend, err := NewDefaultBackend(true) + if err != nil { + t.Fatalf("NewDefaultBackend failed: %v", err) + } + + if backend.Enabled != true { + t.Errorf("expected Enabled to be true, got %v", backend.Enabled) + } + if backend.Dialer == nil { + t.Error("expected Dialer to be set") + } + }) + + t.Run("missing SMTP config returns error", func(t *testing.T) { + t.Setenv("SmtpHost", "") + t.Setenv("SmtpPort", "") + t.Setenv("SmtpUsername", "") + t.Setenv("SmtpPassword", "") + + _, err := NewDefaultBackend(true) + if err == nil { + t.Error("expected error when SMTP not configured") + } + if err != ErrSMTPNotConfigured { + t.Errorf("expected ErrSMTPNotConfigured, got %v", err) + } + }) +} diff --git a/pkg/server/mailer/mailer.go b/pkg/server/mailer/mailer.go index 786d703f..d02d8911 100644 --- a/pkg/server/mailer/mailer.go +++ b/pkg/server/mailer/mailer.go @@ -21,13 +21,11 @@ package mailer import ( "bytes" - "embed" "fmt" - htemplate "html/template" "io" ttemplate "text/template" - "github.com/aymerick/douceur/inliner" + "github.com/dnote/dnote/pkg/server/mailer/templates" "github.com/pkg/errors" ) @@ -40,13 +38,9 @@ var ( EmailTypeEmailVerification = "verify_email" // EmailTypeWelcome represents an welcome email EmailTypeWelcome = "welcome" - // EmailTypeInactiveReminder represents an inactivity reminder email - EmailTypeInactiveReminder = "inactive" ) var ( - // EmailKindHTML is the type of html email - EmailKindHTML = "text/html" // EmailKindText is the type of text email EmailKindText = "text/plain" ) @@ -60,9 +54,6 @@ type template interface { // Templates holds the parsed email templates type Templates map[string]template -//go:embed templates/src -var templateDir embed.FS - func getTemplateKey(name, kind string) string { return fmt.Sprintf("%s.%s", name, kind) } @@ -100,58 +91,21 @@ func NewTemplates() Templates { if err != nil { panic(errors.Wrap(err, "initializing password reset template")) } - inactiveReminderText, err := initTextTmpl(EmailTypeInactiveReminder) - if err != nil { - panic(errors.Wrap(err, "initializing password reset template")) - } T := Templates{} T.set(EmailTypeResetPassword, EmailKindText, passwordResetText) T.set(EmailTypeResetPasswordAlert, EmailKindText, passwordResetAlertText) T.set(EmailTypeEmailVerification, EmailKindText, verifyEmailText) T.set(EmailTypeWelcome, EmailKindText, welcomeText) - T.set(EmailTypeInactiveReminder, EmailKindText, inactiveReminderText) return T } -// initHTMLTmpl returns a template instance by parsing the template with the -// given name along with partials -func initHTMLTmpl(templateName string) (template, error) { - filename := fmt.Sprintf("templates/src/%s.html", templateName) - - content, err := templateDir.ReadFile(filename) - if err != nil { - return nil, errors.Wrap(err, "reading template") - } - headerContent, err := templateDir.ReadFile("templates/header.html") - if err != nil { - return nil, errors.Wrap(err, "reading header template") - } - footerContent, err := templateDir.ReadFile("templates/footer.html") - if err != nil { - return nil, errors.Wrap(err, "reading footer template") - } - - t := htemplate.New(templateName) - if _, err = t.Parse(string(content)); err != nil { - return nil, errors.Wrap(err, "parsing template") - } - if _, err = t.Parse(string(headerContent)); err != nil { - return nil, errors.Wrap(err, "parsing template") - } - if _, err = t.Parse(string(footerContent)); err != nil { - return nil, errors.Wrap(err, "parsing template") - } - - return t, nil -} - // initTextTmpl returns a template instance by parsing the template with the given name func initTextTmpl(templateName string) (template, error) { - filename := fmt.Sprintf("templates/src/%s.txt", templateName) + filename := fmt.Sprintf("%s.txt", templateName) - content, err := templateDir.ReadFile(filename) + content, err := templates.Files.ReadFile(filename) if err != nil { return nil, errors.Wrap(err, "reading template") } @@ -165,7 +119,7 @@ func initTextTmpl(templateName string) (template, error) { } // Execute executes the template with the given name with the givn data -func (tmpl Templates) Execute(name, kind string, data interface{}) (string, error) { +func (tmpl Templates) Execute(name, kind string, data any) (string, error) { t, err := tmpl.get(name, kind) if err != nil { return "", errors.Wrap(err, "getting template") @@ -176,15 +130,5 @@ func (tmpl Templates) Execute(name, kind string, data interface{}) (string, erro return "", errors.Wrap(err, "executing the template") } - // If HTML email, inline the CSS rules - if kind == EmailKindHTML { - html, err := inliner.Inline(buf.String()) - if err != nil { - return "", errors.Wrap(err, "inlining the css rules") - } - - return html, nil - } - return buf.String(), nil } diff --git a/pkg/server/mailer/mailer_test.go b/pkg/server/mailer/mailer_test.go index 6f24b4fb..df95b1f9 100644 --- a/pkg/server/mailer/mailer_test.go +++ b/pkg/server/mailer/mailer_test.go @@ -26,6 +26,26 @@ import ( "github.com/pkg/errors" ) +func TestAllTemplatesInitialized(t *testing.T) { + tmpl := NewTemplates() + + emailTypes := []string{ + EmailTypeResetPassword, + EmailTypeResetPasswordAlert, + EmailTypeEmailVerification, + EmailTypeWelcome, + } + + for _, emailType := range emailTypes { + t.Run(emailType, func(t *testing.T) { + _, err := tmpl.get(emailType, EmailKindText) + if err != nil { + t.Errorf("template %s not initialized: %v", emailType, err) + } + }) + } +} + func TestEmailVerificationEmail(t *testing.T) { testCases := []struct { token string @@ -101,3 +121,79 @@ func TestResetPasswordEmail(t *testing.T) { }) } } + +func TestWelcomeEmail(t *testing.T) { + testCases := []struct { + accountEmail string + webURL string + }{ + { + accountEmail: "test@example.com", + webURL: "http://localhost:3000", + }, + { + accountEmail: "user@example.org", + webURL: "http://localhost:3001", + }, + } + + tmpl := NewTemplates() + + for _, tc := range testCases { + t.Run(fmt.Sprintf("with WebURL %s and email %s", tc.webURL, tc.accountEmail), func(t *testing.T) { + dat := WelcomeTmplData{ + AccountEmail: tc.accountEmail, + WebURL: tc.webURL, + } + body, err := tmpl.Execute(EmailTypeWelcome, EmailKindText, dat) + if err != nil { + t.Fatal(errors.Wrap(err, "executing")) + } + + if ok := strings.Contains(body, tc.webURL); !ok { + t.Errorf("email body did not contain %s", tc.webURL) + } + if ok := strings.Contains(body, tc.accountEmail); !ok { + t.Errorf("email body did not contain %s", tc.accountEmail) + } + }) + } +} + +func TestResetPasswordAlertEmail(t *testing.T) { + testCases := []struct { + accountEmail string + webURL string + }{ + { + accountEmail: "test@example.com", + webURL: "http://localhost:3000", + }, + { + accountEmail: "user@example.org", + webURL: "http://localhost:3001", + }, + } + + tmpl := NewTemplates() + + for _, tc := range testCases { + t.Run(fmt.Sprintf("with WebURL %s and email %s", tc.webURL, tc.accountEmail), func(t *testing.T) { + dat := EmailResetPasswordAlertTmplData{ + AccountEmail: tc.accountEmail, + WebURL: tc.webURL, + } + body, err := tmpl.Execute(EmailTypeResetPasswordAlert, EmailKindText, dat) + if err != nil { + t.Fatal(errors.Wrap(err, "executing")) + } + + if ok := strings.Contains(body, tc.webURL); !ok { + t.Errorf("email body did not contain %s", tc.webURL) + } + if ok := strings.Contains(body, tc.accountEmail); !ok { + t.Errorf("email body did not contain %s", tc.accountEmail) + } + }) + } +} diff --git a/pkg/server/mailer/templates/.env.dev b/pkg/server/mailer/templates/.env.dev deleted file mode 100644 index 7808cb4a..00000000 --- a/pkg/server/mailer/templates/.env.dev +++ /dev/null @@ -1,12 +0,0 @@ -DBHost=localhost -DBPort=5433 -DBName=dnote -DBUser=postgres -DBPassword= - -SmtpUsername=mock-SmtpUsername -SmtpPassword=mock-SmtpPassword -SmtpHost=mock-SmtpHost - -WebURL=http://localhost:3000 -DisableRegistration=false diff --git a/pkg/server/mailer/templates/.gitignore b/pkg/server/mailer/templates/.gitignore deleted file mode 100644 index f8a26871..00000000 --- a/pkg/server/mailer/templates/.gitignore +++ /dev/null @@ -1 +0,0 @@ -templates diff --git a/pkg/server/mailer/templates/README.md b/pkg/server/mailer/templates/README.md deleted file mode 100644 index 9329442a..00000000 --- a/pkg/server/mailer/templates/README.md +++ /dev/null @@ -1,13 +0,0 @@ -# templates - -Email templates - -* `/src` contains templates. - -## Development - -Run the server to develop templates locally. - -``` -./dev.sh -``` diff --git a/pkg/server/mailer/templates/dev.sh b/pkg/server/mailer/templates/dev.sh deleted file mode 100755 index 035220b1..00000000 --- a/pkg/server/mailer/templates/dev.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/usr/bin/env bash -set -eux - -PID="" - -function cleanup { - if [ "$PID" != "" ]; then - kill "$PID" - fi -} -trap cleanup EXIT - -while true; do - go build main.go - ./main & - PID=$! - inotifywait -r -e modify . - kill $PID -done - - diff --git a/pkg/server/mailer/templates/main.go b/pkg/server/mailer/templates/main.go deleted file mode 100644 index 90ed940c..00000000 --- a/pkg/server/mailer/templates/main.go +++ /dev/null @@ -1,144 +0,0 @@ -/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors - * - * This file is part of Dnote. - * - * Dnote is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * Dnote is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with Dnote. If not, see . - */ - -package main - -import ( - "log" - "net/http" - - "github.com/dnote/dnote/pkg/server/config" - "github.com/dnote/dnote/pkg/server/database" - "github.com/dnote/dnote/pkg/server/mailer" - "gorm.io/gorm" - "github.com/joho/godotenv" - _ "github.com/lib/pq" -) - -func (c Context) passwordResetHandler(w http.ResponseWriter, r *http.Request) { - data := mailer.EmailResetPasswordTmplData{ - AccountEmail: "alice@example.com", - Token: "testToken", - WebURL: "http://localhost:3000", - } - body, err := c.Tmpl.Execute(mailer.EmailTypeResetPassword, mailer.EmailKindText, data) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.Write([]byte(body)) -} - -func (c Context) passwordResetAlertHandler(w http.ResponseWriter, r *http.Request) { - data := mailer.EmailResetPasswordAlertTmplData{ - AccountEmail: "alice@example.com", - WebURL: "http://localhost:3000", - } - body, err := c.Tmpl.Execute(mailer.EmailTypeResetPasswordAlert, mailer.EmailKindText, data) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.Write([]byte(body)) -} - -func (c Context) emailVerificationHandler(w http.ResponseWriter, r *http.Request) { - data := mailer.EmailVerificationTmplData{ - Token: "testToken", - WebURL: "http://localhost:3000", - } - body, err := c.Tmpl.Execute(mailer.EmailTypeEmailVerification, mailer.EmailKindText, data) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.Write([]byte(body)) -} - -func (c Context) welcomeHandler(w http.ResponseWriter, r *http.Request) { - data := mailer.WelcomeTmplData{ - AccountEmail: "alice@example.com", - WebURL: "http://localhost:3000", - } - body, err := c.Tmpl.Execute(mailer.EmailTypeWelcome, mailer.EmailKindText, data) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.Write([]byte(body)) -} - -func (c Context) inactiveHandler(w http.ResponseWriter, r *http.Request) { - data := mailer.InactiveReminderTmplData{ - SampleNoteUUID: "some-uuid", - WebURL: "http://localhost:3000", - Token: "some-random-token", - } - body, err := c.Tmpl.Execute(mailer.EmailTypeInactiveReminder, mailer.EmailKindText, data) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.Write([]byte(body)) -} - -func (c Context) homeHandler(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("Email development server is running.")) -} - -func init() { - err := godotenv.Load(".env.dev") - if err != nil { - panic(err) - } -} - -// Context is a context holding global information -type Context struct { - DB *gorm.DB - Tmpl mailer.Templates -} - -func main() { - c := config.Load() - db := database.Open(c) - defer func() { - sqlDB, err := db.DB() - if err == nil { - sqlDB.Close() - } - }() - - log.Println("Email template development server running on http://127.0.0.1:2300") - - tmpl := mailer.NewTemplates() - ctx := Context{DB: db, Tmpl: tmpl} - - http.HandleFunc("/", ctx.homeHandler) - http.HandleFunc("/email-verification", ctx.emailVerificationHandler) - http.HandleFunc("/password-reset", ctx.passwordResetHandler) - http.HandleFunc("/password-reset-alert", ctx.passwordResetAlertHandler) - http.HandleFunc("/welcome", ctx.welcomeHandler) - http.HandleFunc("/inactive-reminder", ctx.inactiveHandler) - log.Fatal(http.ListenAndServe(":2300", nil)) -} diff --git a/pkg/server/mailer/templates/reset_password.txt b/pkg/server/mailer/templates/reset_password.txt new file mode 100644 index 00000000..9053a493 --- /dev/null +++ b/pkg/server/mailer/templates/reset_password.txt @@ -0,0 +1,5 @@ +You are receiving this because you requested to reset the password of the '{{ .AccountEmail }}' Dnote account. + +Please click on the following link, or paste this into your browser to complete the process: + + {{ .WebURL }}/password-reset/{{ .Token }} diff --git a/pkg/server/mailer/templates/src/reset_password_alert.txt b/pkg/server/mailer/templates/reset_password_alert.txt similarity index 50% rename from pkg/server/mailer/templates/src/reset_password_alert.txt rename to pkg/server/mailer/templates/reset_password_alert.txt index 3aa9bdd6..16957375 100644 --- a/pkg/server/mailer/templates/src/reset_password_alert.txt +++ b/pkg/server/mailer/templates/reset_password_alert.txt @@ -2,7 +2,7 @@ Hi, This email is to notify you that the password for your Dnote account "{{ .AccountEmail }}" has changed. -If you did not initiate this password change, please notify us by replying, and reset your password at {{ .WebURL }}/password-reset +If you did not initiate this password change, reset your password at {{ .WebURL }}/password-reset. Thanks. diff --git a/pkg/server/mailer/templates/scripts/run.sh b/pkg/server/mailer/templates/scripts/run.sh deleted file mode 100755 index fd8e8ac5..00000000 --- a/pkg/server/mailer/templates/scripts/run.sh +++ /dev/null @@ -1 +0,0 @@ -CompileDaemon -directory=. -command="./templates" -include="*.html" diff --git a/pkg/server/mailer/templates/src/inactive.txt b/pkg/server/mailer/templates/src/inactive.txt deleted file mode 100644 index b6f4d508..00000000 --- a/pkg/server/mailer/templates/src/inactive.txt +++ /dev/null @@ -1,9 +0,0 @@ -Hi, nothing has been added to your Dnote for some time. - -What about revisiting one of your previous notes? {{ .WebURL }}/notes/{{ .SampleNoteUUID }} - -You can add new notes at {{ .WebURL }}/new or using Dnote apps. - -- Dnote team - -UNSUBSCRIBE: {{ .WebURL }}/settings/notifications?token={{ .Token }} diff --git a/pkg/server/mailer/templates/src/reset_password.txt b/pkg/server/mailer/templates/src/reset_password.txt deleted file mode 100644 index 3bc34850..00000000 --- a/pkg/server/mailer/templates/src/reset_password.txt +++ /dev/null @@ -1,9 +0,0 @@ -You are receiving this because you (or someone else) requested to reset the password of the '{{ .AccountEmail }}' Dnote account. - -Please click on the following link, or paste this into your browser to complete the process: - - {{ .WebURL }}/password-reset/{{ .Token }} - -You can reply to this message, if you have questions. - -- Dnote team diff --git a/pkg/server/app/main_test.go b/pkg/server/mailer/templates/templates.go similarity index 78% rename from pkg/server/app/main_test.go rename to pkg/server/mailer/templates/templates.go index d757da42..c59d3a14 100644 --- a/pkg/server/app/main_test.go +++ b/pkg/server/mailer/templates/templates.go @@ -16,20 +16,10 @@ * along with Dnote. If not, see . */ -package app +// Package mailer provides a functionality to send emails +package templates -import ( - "os" - "testing" +import "embed" - "github.com/dnote/dnote/pkg/server/testutils" -) - -func TestMain(m *testing.M) { - testutils.InitTestDB() - - code := m.Run() - testutils.ClearData(testutils.DB) - - os.Exit(code) -} +//go:embed *.txt +var Files embed.FS diff --git a/pkg/server/mailer/templates/src/verify_email.txt b/pkg/server/mailer/templates/verify_email.txt similarity index 72% rename from pkg/server/mailer/templates/src/verify_email.txt rename to pkg/server/mailer/templates/verify_email.txt index a85ab705..c21af88d 100644 --- a/pkg/server/mailer/templates/src/verify_email.txt +++ b/pkg/server/mailer/templates/verify_email.txt @@ -1,9 +1,5 @@ -Hi. +Hi, Welcome to Dnote! To verify your email, visit the following link: {{ .WebURL }}/verify-email/{{ .Token }} - -Thanks for using Dnote. - -- Dnote team diff --git a/pkg/server/mailer/templates/src/welcome.txt b/pkg/server/mailer/templates/welcome.txt similarity index 83% rename from pkg/server/mailer/templates/src/welcome.txt rename to pkg/server/mailer/templates/welcome.txt index 7a33207a..72d0fdf0 100644 --- a/pkg/server/mailer/templates/src/welcome.txt +++ b/pkg/server/mailer/templates/welcome.txt @@ -10,7 +10,3 @@ If you ever forget your password, you can reset it at {{ .WebURL }}/password-res SOURCE CODE Dnote is open source and you can see the source code at https://github.com/dnote/dnote - -Feel free to reply anytime. Thanks for using Dnote. - -- Dnote team diff --git a/pkg/server/mailer/tokens.go b/pkg/server/mailer/tokens.go index 7d78725f..0f751a4e 100644 --- a/pkg/server/mailer/tokens.go +++ b/pkg/server/mailer/tokens.go @@ -21,19 +21,17 @@ package mailer import ( "crypto/rand" "encoding/base64" - "errors" "github.com/dnote/dnote/pkg/server/database" - pkgErrors "github.com/pkg/errors" + "github.com/pkg/errors" "gorm.io/gorm" ) func generateRandomToken(bits int) (string, error) { b := make([]byte, bits) - _, err := rand.Read(b) - if err != nil { - return "", pkgErrors.Wrap(err, "generating random bytes") + if _, err := rand.Read(b); err != nil { + return "", errors.Wrap(err, "generating random bytes") } return base64.URLEncoding.EncodeToString(b), nil @@ -49,7 +47,7 @@ func GetToken(db *gorm.DB, userID int, kind string) (database.Token, error) { tokenVal, genErr := generateRandomToken(16) if genErr != nil { - return tok, pkgErrors.Wrap(genErr, "generating token value") + return tok, errors.Wrap(genErr, "generating token value") } if errors.Is(err, gorm.ErrRecordNotFound) { @@ -59,12 +57,12 @@ func GetToken(db *gorm.DB, userID int, kind string) (database.Token, error) { Value: tokenVal, } if err := db.Save(&tok).Error; err != nil { - return tok, pkgErrors.Wrap(err, "saving token") + return tok, errors.Wrap(err, "saving token") } return tok, nil } else if err != nil { - return tok, pkgErrors.Wrap(err, "finding token") + return tok, errors.Wrap(err, "finding token") } return tok, nil diff --git a/pkg/server/mailer/tokens_test.go b/pkg/server/mailer/tokens_test.go new file mode 100644 index 00000000..72a85fc6 --- /dev/null +++ b/pkg/server/mailer/tokens_test.go @@ -0,0 +1,83 @@ +/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors + * + * This file is part of Dnote. + * + * Dnote is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Dnote is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Dnote. If not, see . + */ + +package mailer + +import ( + "testing" + + "github.com/dnote/dnote/pkg/server/database" + "github.com/dnote/dnote/pkg/server/testutils" +) + +func TestGetToken(t *testing.T) { + db := testutils.InitMemoryDB(t) + + userID := 1 + tokenType := "email_verification" + + t.Run("creates new token", func(t *testing.T) { + token, err := GetToken(db, userID, tokenType) + if err != nil { + t.Fatalf("GetToken failed: %v", err) + } + + if token.UserID != userID { + t.Errorf("expected UserID %d, got %d", userID, token.UserID) + } + if token.Type != tokenType { + t.Errorf("expected Type %s, got %s", tokenType, token.Type) + } + if token.Value == "" { + t.Error("expected non-empty token Value") + } + if token.UsedAt != nil { + t.Error("expected UsedAt to be nil for new token") + } + }) + + t.Run("reuses unused token", func(t *testing.T) { + // Get token again - should return the same one + token2, err := GetToken(db, userID, tokenType) + if err != nil { + t.Fatalf("second GetToken failed: %v", err) + } + + // Get first token to compare + var token1 database.Token + if err := db.Where("user_id = ? AND type = ?", userID, tokenType).First(&token1).Error; err != nil { + t.Fatalf("failed to get first token: %v", err) + } + + if token1.ID != token2.ID { + t.Errorf("expected same token ID %d, got %d", token1.ID, token2.ID) + } + if token1.Value != token2.Value { + t.Errorf("expected same token Value %s, got %s", token1.Value, token2.Value) + } + + // Verify only one token exists in database + var count int64 + if err := db.Model(&database.Token{}).Where("user_id = ? AND type = ?", userID, tokenType).Count(&count).Error; err != nil { + t.Fatalf("failed to count tokens: %v", err) + } + if count != 1 { + t.Errorf("expected 1 token in database, got %d", count) + } + }) +} diff --git a/pkg/server/mailer/types.go b/pkg/server/mailer/types.go index da9a448c..3a371911 100644 --- a/pkg/server/mailer/types.go +++ b/pkg/server/mailer/types.go @@ -42,16 +42,3 @@ type WelcomeTmplData struct { AccountEmail string WebURL string } - -// InactiveReminderTmplData is a template data for welcome emails -type InactiveReminderTmplData struct { - SampleNoteUUID string - WebURL string - Token string -} - -// EmailTypeSubscriptionConfirmationTmplData is a template data for reset password emails -type EmailTypeSubscriptionConfirmationTmplData struct { - AccountEmail string - WebURL string -} diff --git a/pkg/server/main.go b/pkg/server/main.go index d8df31e0..4e36b96d 100644 --- a/pkg/server/main.go +++ b/pkg/server/main.go @@ -21,8 +21,8 @@ package main import ( "flag" "fmt" - "log" "net/http" + "os" "github.com/dnote/dnote/pkg/clock" "github.com/dnote/dnote/pkg/server/app" @@ -30,54 +30,81 @@ import ( "github.com/dnote/dnote/pkg/server/config" "github.com/dnote/dnote/pkg/server/controllers" "github.com/dnote/dnote/pkg/server/database" - "github.com/dnote/dnote/pkg/server/job" + "github.com/dnote/dnote/pkg/server/log" "github.com/dnote/dnote/pkg/server/mailer" - "gorm.io/driver/postgres" - "gorm.io/gorm" - "github.com/pkg/errors" + "gorm.io/gorm" ) -var port = flag.String("port", "3000", "port to connect to") - -func initDB(c config.Config) *gorm.DB { - db, err := gorm.Open(postgres.Open(c.DB.GetConnectionStr()), &gorm.Config{}) - if err != nil { - panic(errors.Wrap(err, "opening database connection")) - } +func initDB(dbPath string) *gorm.DB { + db := database.Open(dbPath) database.InitSchema(db) + database.Migrate(db) return db } func initApp(cfg config.Config) app.App { - db := initDB(cfg) + db := initDB(cfg.DBPath) + + emailBackend, err := mailer.NewDefaultBackend(cfg.IsProd()) + if err != nil { + emailBackend = &mailer.DefaultBackend{Enabled: false} + } else { + log.Info("Email backend configured") + } return app.App{ - DB: db, - Clock: clock.New(), - EmailTemplates: mailer.NewTemplates(), - EmailBackend: &mailer.SimpleBackendImplementation{}, - Config: cfg, - HTTP500Page: cfg.HTTP500Page, + DB: db, + Clock: clock.New(), + EmailTemplates: mailer.NewTemplates(), + EmailBackend: emailBackend, + HTTP500Page: cfg.HTTP500Page, + AppEnv: cfg.AppEnv, + WebURL: cfg.WebURL, + DisableRegistration: cfg.DisableRegistration, + Port: cfg.Port, + DBPath: cfg.DBPath, + AssetBaseURL: cfg.AssetBaseURL, } } -func runJob(a app.App) error { - runner, err := job.NewRunner(a.DB, a.Clock, a.EmailTemplates, a.EmailBackend, a.Config) +func startCmd(args []string) { + startFlags := flag.NewFlagSet("start", flag.ExitOnError) + startFlags.Usage = func() { + fmt.Printf(`Usage: + dnote-server start [flags] + +Flags: +`) + startFlags.PrintDefaults() + } + + appEnv := startFlags.String("appEnv", "", "Application environment (env: APP_ENV, default: PRODUCTION)") + port := startFlags.String("port", "", "Server port (env: PORT, default: 3000)") + webURL := startFlags.String("webUrl", "", "Full URL to server without trailing slash (env: WebURL, example: https://example.com)") + dbPath := startFlags.String("dbPath", "", "Path to SQLite database file (env: DBPath, default: $XDG_DATA_HOME/dnote/server.db)") + disableRegistration := startFlags.Bool("disableRegistration", false, "Disable user registration (env: DisableRegistration, default: false)") + logLevel := startFlags.String("logLevel", "", "Log level: debug, info, warn, or error (env: LOG_LEVEL, default: info)") + + startFlags.Parse(args) + + cfg, err := config.New(config.Params{ + AppEnv: *appEnv, + Port: *port, + WebURL: *webURL, + DBPath: *dbPath, + DisableRegistration: *disableRegistration, + LogLevel: *logLevel, + }) if err != nil { - return errors.Wrap(err, "getting a job runner") - } - if err := runner.Do(); err != nil { - return errors.Wrap(err, "running job") + fmt.Printf("Error: %s\n\n", err) + startFlags.Usage() + os.Exit(1) } - return nil -} - -func startCmd() { - cfg := config.Load() - cfg.SetAssetBaseURL("/static") + // Set log level + log.SetLevel(cfg.LogLevel) app := initApp(cfg) defer func() { @@ -87,13 +114,6 @@ func startCmd() { } }() - if err := database.Migrate(app.DB); err != nil { - panic(errors.Wrap(err, "running migrations")) - } - if err := runJob(app); err != nil { - panic(errors.Wrap(err, "running job")) - } - ctl := controllers.New(&app) rc := controllers.RouteConfig{ WebRoutes: controllers.NewWebRoutes(&app, ctl), @@ -106,8 +126,15 @@ func startCmd() { panic(errors.Wrap(err, "initializing router")) } - log.Printf("Dnote version %s is running on port %s", buildinfo.Version, *port) - log.Fatalln(http.ListenAndServe(fmt.Sprintf(":%s", *port), r)) + log.WithFields(log.Fields{ + "version": buildinfo.Version, + "port": cfg.Port, + }).Info("Dnote server starting") + + if err := http.ListenAndServe(fmt.Sprintf(":%s", cfg.Port), r); err != nil { + log.ErrorWrap(err, "server failed") + os.Exit(1) + } } func versionCmd() { @@ -115,29 +142,33 @@ func versionCmd() { } func rootCmd() { - fmt.Printf(`Dnote server - a simple personal knowledge base + fmt.Printf(`Dnote server - a simple command line notebook Usage: - dnote-server [command] + dnote-server [command] [flags] Available commands: - start: Start the server + start: Start the server (use 'dnote-server start --help' for flags) version: Print the version `) } func main() { - flag.Parse() - cmd := flag.Arg(0) + if len(os.Args) < 2 { + rootCmd() + return + } + + cmd := os.Args[1] switch cmd { - case "": - rootCmd() case "start": - startCmd() + startCmd(os.Args[2:]) case "version": versionCmd() default: - fmt.Printf("Unknown command %s", cmd) + fmt.Printf("Unknown command %s\n", cmd) + rootCmd() + os.Exit(1) } } diff --git a/pkg/server/middleware/auth.go b/pkg/server/middleware/auth.go index 28af8760..984079e4 100644 --- a/pkg/server/middleware/auth.go +++ b/pkg/server/middleware/auth.go @@ -22,19 +22,17 @@ import ( "errors" "net/http" "net/url" - "strings" "time" - "github.com/dnote/dnote/pkg/server/app" "github.com/dnote/dnote/pkg/server/context" "github.com/dnote/dnote/pkg/server/database" "github.com/dnote/dnote/pkg/server/helpers" "github.com/dnote/dnote/pkg/server/log" - "gorm.io/gorm" pkgErrors "github.com/pkg/errors" + "gorm.io/gorm" ) -func authWithToken(db *gorm.DB, r *http.Request, tokenType string, p *AuthParams) (database.User, database.Token, bool, error) { +func authWithToken(db *gorm.DB, r *http.Request, tokenType string) (database.User, database.Token, bool, error) { var user database.User var token database.Token @@ -62,32 +60,17 @@ func authWithToken(db *gorm.DB, r *http.Request, tokenType string, p *AuthParams return user, token, true, nil } -// Cors allows browser extensions to load resources -func Cors(next http.HandlerFunc) http.HandlerFunc { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - origin := r.Header.Get("Origin") - - // Allow browser extensions - if strings.HasPrefix(origin, "moz-extension://") || strings.HasPrefix(origin, "chrome-extension://") { - w.Header().Set("Access-Control-Allow-Origin", origin) - } - - next.ServeHTTP(w, r) - }) -} - // AuthParams is the params for the authentication middleware type AuthParams struct { - ProOnly bool RedirectGuestsToLogin bool } // Auth is an authentication middleware -func Auth(a *app.App, next http.HandlerFunc, p *AuthParams) http.HandlerFunc { - next = WithAccount(a, next) +func Auth(db *gorm.DB, next http.HandlerFunc, p *AuthParams) http.HandlerFunc { + next = WithAccount(db, next) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user, ok, err := AuthWithSession(a.DB, r) + user, ok, err := AuthWithSession(db, r) if !ok { if p != nil && p.RedirectGuestsToLogin { @@ -107,25 +90,18 @@ func Auth(a *app.App, next http.HandlerFunc, p *AuthParams) http.HandlerFunc { return } - if p != nil && p.ProOnly { - if !user.Cloud { - RespondForbidden(w) - return - } - } - ctx := context.WithUser(r.Context(), &user) next.ServeHTTP(w, r.WithContext(ctx)) }) } -func WithAccount(a *app.App, next http.HandlerFunc) http.HandlerFunc { +func WithAccount(db *gorm.DB, next http.HandlerFunc) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user := context.User(r.Context()) var account database.Account - if err := a.DB.Where("user_id = ?", user.ID).First(&account).Error; err != nil { + if err := db.Where("user_id = ?", user.ID).First(&account).Error; err != nil { DoError(w, "finding account", err, http.StatusInternalServerError) return } @@ -137,9 +113,9 @@ func WithAccount(a *app.App, next http.HandlerFunc) http.HandlerFunc { } // TokenAuth is an authentication middleware with token -func TokenAuth(a *app.App, next http.HandlerFunc, tokenType string, p *AuthParams) http.HandlerFunc { +func TokenAuth(db *gorm.DB, next http.HandlerFunc, tokenType string, p *AuthParams) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user, token, ok, err := authWithToken(a.DB, r, tokenType, p) + user, token, ok, err := authWithToken(db, r, tokenType) if err != nil { // log the error and continue log.ErrorWrap(err, "authenticating with token") @@ -151,7 +127,7 @@ func TokenAuth(a *app.App, next http.HandlerFunc, tokenType string, p *AuthParam ctx = context.WithToken(ctx, &token) } else { // If token-based auth fails, fall back to session-based auth - user, ok, err = AuthWithSession(a.DB, r) + user, ok, err = AuthWithSession(db, r) if err != nil { DoError(w, "authenticating with session", err, http.StatusInternalServerError) return @@ -163,13 +139,6 @@ func TokenAuth(a *app.App, next http.HandlerFunc, tokenType string, p *AuthParam } } - if p != nil && p.ProOnly { - if !user.Cloud { - RespondForbidden(w) - return - } - } - ctx = context.WithUser(ctx, &user) next.ServeHTTP(w, r.WithContext(ctx)) }) @@ -211,9 +180,9 @@ func AuthWithSession(db *gorm.DB, r *http.Request) (database.User, bool, error) return user, true, nil } -func GuestOnly(a *app.App, next http.HandlerFunc) http.HandlerFunc { +func GuestOnly(db *gorm.DB, next http.HandlerFunc) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, ok, err := AuthWithSession(a.DB, r) + _, ok, err := AuthWithSession(db, r) if err != nil { // log the error and continue log.ErrorWrap(err, "authenticating with session") diff --git a/pkg/server/middleware/auth_test.go b/pkg/server/middleware/auth_test.go new file mode 100644 index 00000000..8451ae5d --- /dev/null +++ b/pkg/server/middleware/auth_test.go @@ -0,0 +1,235 @@ +/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors + * + * This file is part of Dnote. + * + * Dnote is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Dnote is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Dnote. If not, see . + */ + +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/dnote/dnote/pkg/assert" + "github.com/dnote/dnote/pkg/server/database" + "github.com/dnote/dnote/pkg/server/testutils" +) + +func TestGuestOnly(t *testing.T) { + db := testutils.InitMemoryDB(t) + + handler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + } + + server := httptest.NewServer(GuestOnly(db, handler)) + defer server.Close() + + t.Run("guest", func(t *testing.T) { + req := testutils.MakeReq(server.URL, "GET", "/", "") + res := testutils.HTTPDo(t, req) + + assert.Equal(t, res.StatusCode, http.StatusOK, "status code mismatch") + }) + + t.Run("logged in", func(t *testing.T) { + user := testutils.SetupUserData(db) + req := testutils.MakeReq(server.URL, "GET", "/", "") + res := testutils.HTTPAuthDo(t, db, req, user) + + assert.Equal(t, res.StatusCode, http.StatusFound, "status code mismatch") + assert.Equal(t, res.Header.Get("Location"), "/", "location mismatch") + }) + + t.Run("error getting credential", func(t *testing.T) { + req := testutils.MakeReq(server.URL, "GET", "/", "") + req.Header.Set("Authorization", "InvalidFormat") + res := testutils.HTTPDo(t, req) + + assert.Equal(t, res.StatusCode, http.StatusOK, "status code mismatch") + }) +} + +func TestAuth(t *testing.T) { + db := testutils.InitMemoryDB(t) + + user := testutils.SetupUserData(db) + testutils.SetupAccountData(db, user, "alice@test.com", "pass1234") + + session := database.Session{ + Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=", + UserID: user.ID, + ExpiresAt: time.Now().Add(time.Hour * 24), + } + testutils.MustExec(t, db.Save(&session), "preparing session") + expiredSession := database.Session{ + Key: "Vvgm3eBXfXGEFWERI7faiRJ3DAzJw+7DdT9J1LEyNfI=", + UserID: user.ID, + ExpiresAt: time.Now().Add(-time.Hour * 24), + } + testutils.MustExec(t, db.Save(&expiredSession), "preparing expired session") + + handler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + } + + t.Run("valid session with header", func(t *testing.T) { + server := httptest.NewServer(Auth(db, handler, nil)) + defer server.Close() + + req := testutils.MakeReq(server.URL, "GET", "/", "") + req.Header.Set("Authorization", "Bearer "+session.Key) + res := testutils.HTTPDo(t, req) + + assert.Equal(t, res.StatusCode, http.StatusOK, "status code mismatch") + }) + + t.Run("expired session with header", func(t *testing.T) { + server := httptest.NewServer(Auth(db, handler, nil)) + defer server.Close() + + req := testutils.MakeReq(server.URL, "GET", "/", "") + req.Header.Set("Authorization", "Bearer "+expiredSession.Key) + res := testutils.HTTPDo(t, req) + + assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch") + }) + + t.Run("invalid session with header", func(t *testing.T) { + server := httptest.NewServer(Auth(db, handler, nil)) + defer server.Close() + + req := testutils.MakeReq(server.URL, "GET", "/", "") + req.Header.Set("Authorization", "Bearer someInvalidSessionKey=") + res := testutils.HTTPDo(t, req) + + assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch") + }) + + t.Run("valid session with cookie", func(t *testing.T) { + server := httptest.NewServer(Auth(db, handler, nil)) + defer server.Close() + + req := testutils.MakeReq(server.URL, "GET", "/", "") + req.AddCookie(&http.Cookie{ + Name: "id", + Value: session.Key, + HttpOnly: true, + }) + res := testutils.HTTPDo(t, req) + + assert.Equal(t, res.StatusCode, http.StatusOK, "status code mismatch") + }) + + t.Run("expired session with cookie", func(t *testing.T) { + server := httptest.NewServer(Auth(db, handler, nil)) + defer server.Close() + + req := testutils.MakeReq(server.URL, "GET", "/", "") + req.AddCookie(&http.Cookie{ + Name: "id", + Value: expiredSession.Key, + HttpOnly: true, + }) + res := testutils.HTTPDo(t, req) + + assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch") + }) + + t.Run("no auth", func(t *testing.T) { + server := httptest.NewServer(Auth(db, handler, nil)) + defer server.Close() + + req := testutils.MakeReq(server.URL, "GET", "/", "") + res := testutils.HTTPDo(t, req) + + assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch") + }) + + t.Run("redirect guests to login", func(t *testing.T) { + server := httptest.NewServer(Auth(db, handler, &AuthParams{RedirectGuestsToLogin: true})) + defer server.Close() + + req := testutils.MakeReq(server.URL, "GET", "/settings", "") + res := testutils.HTTPDo(t, req) + + assert.Equal(t, res.StatusCode, http.StatusFound, "status code mismatch") + assert.Equal(t, res.Header.Get("Location"), "/login?referrer=%2Fsettings", "location mismatch") + }) +} + +func TestTokenAuth(t *testing.T) { + db := testutils.InitMemoryDB(t) + + user := testutils.SetupUserData(db) + tok := database.Token{ + UserID: user.ID, + Type: database.TokenTypeEmailVerification, + Value: "xpwFnc0MdllFUePDq9DLeQ==", + } + testutils.MustExec(t, db.Save(&tok), "preparing token") + session := database.Session{ + Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=", + UserID: user.ID, + ExpiresAt: time.Now().Add(time.Hour * 24), + } + testutils.MustExec(t, db.Save(&session), "preparing session") + + handler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + } + + server := httptest.NewServer(TokenAuth(db, handler, database.TokenTypeEmailVerification, nil)) + defer server.Close() + + t.Run("with token", func(t *testing.T) { + req := testutils.MakeReq(server.URL, "GET", "/?token=xpwFnc0MdllFUePDq9DLeQ==", "") + res := testutils.HTTPDo(t, req) + + assert.Equal(t, res.StatusCode, http.StatusOK, "status code mismatch") + }) + + t.Run("with invalid token", func(t *testing.T) { + req := testutils.MakeReq(server.URL, "GET", "/?token=someRandomToken==", "") + res := testutils.HTTPDo(t, req) + + assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch") + }) + + t.Run("with session header", func(t *testing.T) { + req := testutils.MakeReq(server.URL, "GET", "/", "") + req.Header.Set("Authorization", "Bearer "+session.Key) + res := testutils.HTTPDo(t, req) + + assert.Equal(t, res.StatusCode, http.StatusOK, "status code mismatch") + }) + + t.Run("with invalid session", func(t *testing.T) { + req := testutils.MakeReq(server.URL, "GET", "/", "") + req.Header.Set("Authorization", "Bearer someInvalidSessionKey=") + res := testutils.HTTPDo(t, req) + + assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch") + }) + + t.Run("without anything", func(t *testing.T) { + req := testutils.MakeReq(server.URL, "GET", "/", "") + res := testutils.HTTPDo(t, req) + + assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch") + }) +} diff --git a/pkg/server/middleware/helpers.go b/pkg/server/middleware/helpers.go index e1059db8..f43d0abb 100644 --- a/pkg/server/middleware/helpers.go +++ b/pkg/server/middleware/helpers.go @@ -92,7 +92,6 @@ func DoError(w http.ResponseWriter, msg string, err error, statusCode int) { // NotSupported is the handler for the route that is no longer supported func NotSupported(w http.ResponseWriter, r *http.Request) { http.Error(w, "API version is not supported. Please upgrade your client.", http.StatusGone) - return } // getSessionKeyFromCookie reads and returns a session key from the cookie sent by the diff --git a/pkg/server/middleware/helpers_test.go b/pkg/server/middleware/helpers_test.go index 1368dd70..623ec818 100644 --- a/pkg/server/middleware/helpers_test.go +++ b/pkg/server/middleware/helpers_test.go @@ -19,16 +19,10 @@ package middleware import ( - "fmt" "net/http" - "net/http/httptest" "testing" - "time" "github.com/dnote/dnote/pkg/assert" - "github.com/dnote/dnote/pkg/server/app" - "github.com/dnote/dnote/pkg/server/database" - "github.com/dnote/dnote/pkg/server/testutils" "github.com/pkg/errors" ) @@ -180,521 +174,3 @@ func TestGetCredential(t *testing.T) { } } -func TestAuthMiddleware(t *testing.T) { - defer testutils.ClearData(testutils.DB) - - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@test.com", "pass1234") - - session := database.Session{ - Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=", - UserID: user.ID, - ExpiresAt: time.Now().Add(time.Hour * 24), - } - testutils.MustExec(t, testutils.DB.Save(&session), "preparing session") - session2 := database.Session{ - Key: "Vvgm3eBXfXGEFWERI7faiRJ3DAzJw+7DdT9J1LEyNfI=", - UserID: user.ID, - ExpiresAt: time.Now().Add(-time.Hour * 24), - } - testutils.MustExec(t, testutils.DB.Save(&session2), "preparing session") - - handler := func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - } - a := &app.App{DB: testutils.DB} - server := httptest.NewServer(Auth(a, handler, nil)) - defer server.Close() - - t.Run("with header", func(t *testing.T) { - testCases := []struct { - header string - expectedStatus int - }{ - { - header: fmt.Sprintf("Bearer %s", session.Key), - expectedStatus: http.StatusOK, - }, - { - header: fmt.Sprintf("Bearer %s", session2.Key), - expectedStatus: http.StatusUnauthorized, - }, - { - header: fmt.Sprintf("Bearer someInvalidSessionKey="), - expectedStatus: http.StatusUnauthorized, - }, - } - - for _, tc := range testCases { - t.Run(tc.header, func(t *testing.T) { - req := testutils.MakeReq(server.URL, "GET", "/", "") - req.Header.Set("Authorization", tc.header) - - // execute - res := testutils.HTTPDo(t, req) - - // test - assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch") - }) - } - }) - - t.Run("with cookie", func(t *testing.T) { - testCases := []struct { - cookie *http.Cookie - expectedStatus int - }{ - { - cookie: &http.Cookie{ - Name: "id", - Value: session.Key, - HttpOnly: true, - }, - expectedStatus: http.StatusOK, - }, - { - cookie: &http.Cookie{ - Name: "id", - Value: session2.Key, - HttpOnly: true, - }, - expectedStatus: http.StatusUnauthorized, - }, - { - cookie: &http.Cookie{ - Name: "id", - Value: "someInvalidSessionKey=", - HttpOnly: true, - }, - expectedStatus: http.StatusUnauthorized, - }, - } - - for _, tc := range testCases { - t.Run(tc.cookie.Value, func(t *testing.T) { - req := testutils.MakeReq(server.URL, "GET", "/", "") - req.AddCookie(tc.cookie) - - // execute - res := testutils.HTTPDo(t, req) - - // test - assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch") - }) - } - }) - - t.Run("without anything", func(t *testing.T) { - req := testutils.MakeReq(server.URL, "GET", "/", "") - - // execute - res := testutils.HTTPDo(t, req) - - // test - assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch") - }) -} - -func TestAuthMiddleware_ProOnly(t *testing.T) { - defer testutils.ClearData(testutils.DB) - - user := testutils.SetupUserData() - testutils.MustExec(t, testutils.DB.Model(&user).Update("cloud", false), "preparing session") - session := database.Session{ - Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=", - UserID: user.ID, - ExpiresAt: time.Now().Add(time.Hour * 24), - } - testutils.MustExec(t, testutils.DB.Save(&session), "preparing session") - - handler := func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - } - - a := &app.App{DB: testutils.DB} - server := httptest.NewServer(Auth(a, handler, &AuthParams{ - ProOnly: true, - })) - - defer server.Close() - - t.Run("with header", func(t *testing.T) { - testCases := []struct { - header string - expectedStatus int - }{ - { - header: fmt.Sprintf("Bearer %s", session.Key), - expectedStatus: http.StatusForbidden, - }, - { - header: fmt.Sprintf("Bearer someInvalidSessionKey="), - expectedStatus: http.StatusUnauthorized, - }, - } - - for _, tc := range testCases { - t.Run(tc.header, func(t *testing.T) { - req := testutils.MakeReq(server.URL, "GET", "/", "") - req.Header.Set("Authorization", tc.header) - - // execute - res := testutils.HTTPDo(t, req) - - // test - assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch") - }) - } - }) - - t.Run("with cookie", func(t *testing.T) { - testCases := []struct { - cookie *http.Cookie - expectedStatus int - }{ - { - cookie: &http.Cookie{ - Name: "id", - Value: session.Key, - HttpOnly: true, - }, - expectedStatus: http.StatusForbidden, - }, - { - cookie: &http.Cookie{ - Name: "id", - Value: "someInvalidSessionKey=", - HttpOnly: true, - }, - expectedStatus: http.StatusUnauthorized, - }, - } - - for _, tc := range testCases { - t.Run(tc.cookie.Value, func(t *testing.T) { - req := testutils.MakeReq(server.URL, "GET", "/", "") - req.AddCookie(tc.cookie) - - // execute - res := testutils.HTTPDo(t, req) - - // test - assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch") - }) - } - }) -} - -func TestAuthMiddleware_RedirectGuestsToLogin(t *testing.T) { - defer testutils.ClearData(testutils.DB) - - handler := func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - } - - a := &app.App{DB: testutils.DB} - server := httptest.NewServer(Auth(a, handler, &AuthParams{ - RedirectGuestsToLogin: true, - })) - - defer server.Close() - - t.Run("guest", func(t *testing.T) { - req := testutils.MakeReq(server.URL, "GET", "/", "") - - // execute - res := testutils.HTTPDo(t, req) - - // test - assert.Equal(t, res.StatusCode, http.StatusFound, "status code mismatch") - assert.Equal(t, res.Header.Get("Location"), "/login?referrer=%2F", "location header mismatch") - }) - - t.Run("logged in user", func(t *testing.T) { - req := testutils.MakeReq(server.URL, "GET", "/", "") - - user := testutils.SetupUserData() - testutils.SetupAccountData(user, "alice@test.com", "pass1234") - - testutils.MustExec(t, testutils.DB.Model(&user).Update("cloud", false), "preparing session") - session := database.Session{ - Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=", - UserID: user.ID, - ExpiresAt: time.Now().Add(time.Hour * 24), - } - testutils.MustExec(t, testutils.DB.Save(&session), "preparing session") - - // execute - res := testutils.HTTPAuthDo(t, req, user) - req.Header.Set("Authorization", session.Key) - - // test - assert.Equal(t, res.StatusCode, http.StatusOK, "status code mismatch") - assert.Equal(t, res.Header.Get("Location"), "", "location header mismatch") - }) - -} - -func TestTokenAuthMiddleWare(t *testing.T) { - defer testutils.ClearData(testutils.DB) - - user := testutils.SetupUserData() - tok := database.Token{ - UserID: user.ID, - Type: database.TokenTypeEmailPreference, - Value: "xpwFnc0MdllFUePDq9DLeQ==", - } - testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token") - session := database.Session{ - Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=", - UserID: user.ID, - ExpiresAt: time.Now().Add(time.Hour * 24), - } - testutils.MustExec(t, testutils.DB.Save(&session), "preparing session") - - handler := func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - } - - a := &app.App{DB: testutils.DB} - server := httptest.NewServer(TokenAuth(a, handler, database.TokenTypeEmailPreference, nil)) - defer server.Close() - - t.Run("with token", func(t *testing.T) { - testCases := []struct { - token string - expectedStatus int - }{ - { - token: "xpwFnc0MdllFUePDq9DLeQ==", - expectedStatus: http.StatusOK, - }, - { - token: "someRandomToken==", - expectedStatus: http.StatusUnauthorized, - }, - } - - for _, tc := range testCases { - t.Run(tc.token, func(t *testing.T) { - req := testutils.MakeReq(server.URL, "GET", fmt.Sprintf("/?token=%s", tc.token), "") - - // execute - res := testutils.HTTPDo(t, req) - - // test - assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch") - }) - } - }) - - t.Run("with session header", func(t *testing.T) { - testCases := []struct { - header string - expectedStatus int - }{ - { - header: fmt.Sprintf("Bearer %s", session.Key), - expectedStatus: http.StatusOK, - }, - { - header: fmt.Sprintf("Bearer someInvalidSessionKey="), - expectedStatus: http.StatusUnauthorized, - }, - } - - for _, tc := range testCases { - t.Run(tc.header, func(t *testing.T) { - req := testutils.MakeReq(server.URL, "GET", "/", "") - req.Header.Set("Authorization", tc.header) - - // execute - res := testutils.HTTPDo(t, req) - - // test - assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch") - }) - } - }) - - t.Run("with session cookie", func(t *testing.T) { - testCases := []struct { - cookie *http.Cookie - expectedStatus int - }{ - { - cookie: &http.Cookie{ - Name: "id", - Value: session.Key, - HttpOnly: true, - }, - expectedStatus: http.StatusOK, - }, - { - cookie: &http.Cookie{ - Name: "id", - Value: "someInvalidSessionKey=", - HttpOnly: true, - }, - expectedStatus: http.StatusUnauthorized, - }, - } - - for _, tc := range testCases { - t.Run(tc.cookie.Value, func(t *testing.T) { - req := testutils.MakeReq(server.URL, "GET", "/", "") - req.AddCookie(tc.cookie) - - // execute - res := testutils.HTTPDo(t, req) - - // test - assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch") - }) - } - }) - - t.Run("without anything", func(t *testing.T) { - req := testutils.MakeReq(server.URL, "GET", "/", "") - - // execute - res := testutils.HTTPDo(t, req) - - // test - assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch") - }) -} - -func TestTokenAuthMiddleWare_ProOnly(t *testing.T) { - defer testutils.ClearData(testutils.DB) - - user := testutils.SetupUserData() - testutils.MustExec(t, testutils.DB.Model(&user).Update("cloud", false), "preparing session") - tok := database.Token{ - UserID: user.ID, - Type: database.TokenTypeEmailPreference, - Value: "xpwFnc0MdllFUePDq9DLeQ==", - } - testutils.MustExec(t, testutils.DB.Save(&tok), "preparing token") - session := database.Session{ - Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=", - UserID: user.ID, - ExpiresAt: time.Now().Add(time.Hour * 24), - } - testutils.MustExec(t, testutils.DB.Save(&session), "preparing session") - - handler := func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - } - - a := &app.App{DB: testutils.DB} - server := httptest.NewServer(TokenAuth(a, handler, database.TokenTypeEmailPreference, &AuthParams{ - ProOnly: true, - })) - - defer server.Close() - - t.Run("with token", func(t *testing.T) { - testCases := []struct { - token string - expectedStatus int - }{ - { - token: "xpwFnc0MdllFUePDq9DLeQ==", - expectedStatus: http.StatusForbidden, - }, - { - token: "someRandomToken==", - expectedStatus: http.StatusUnauthorized, - }, - } - - for _, tc := range testCases { - t.Run(tc.token, func(t *testing.T) { - req := testutils.MakeReq(server.URL, "GET", fmt.Sprintf("/?token=%s", tc.token), "") - - // execute - res := testutils.HTTPDo(t, req) - - // test - assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch") - }) - } - }) - - t.Run("with session header", func(t *testing.T) { - testCases := []struct { - header string - expectedStatus int - }{ - { - header: fmt.Sprintf("Bearer %s", session.Key), - expectedStatus: http.StatusForbidden, - }, - { - header: fmt.Sprintf("Bearer someInvalidSessionKey="), - expectedStatus: http.StatusUnauthorized, - }, - } - - for _, tc := range testCases { - t.Run(tc.header, func(t *testing.T) { - req := testutils.MakeReq(server.URL, "GET", "/", "") - req.Header.Set("Authorization", tc.header) - - // execute - res := testutils.HTTPDo(t, req) - - // test - assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch") - }) - } - }) - - t.Run("with session cookie", func(t *testing.T) { - testCases := []struct { - cookie *http.Cookie - expectedStatus int - }{ - { - cookie: &http.Cookie{ - Name: "id", - Value: session.Key, - HttpOnly: true, - }, - expectedStatus: http.StatusForbidden, - }, - { - cookie: &http.Cookie{ - Name: "id", - Value: "someInvalidSessionKey=", - HttpOnly: true, - }, - expectedStatus: http.StatusUnauthorized, - }, - } - - for _, tc := range testCases { - t.Run(tc.cookie.Value, func(t *testing.T) { - req := testutils.MakeReq(server.URL, "GET", "/", "") - req.AddCookie(tc.cookie) - - // execute - res := testutils.HTTPDo(t, req) - - // test - assert.Equal(t, res.StatusCode, tc.expectedStatus, "status code mismatch") - }) - } - }) - - t.Run("without anything", func(t *testing.T) { - req := testutils.MakeReq(server.URL, "GET", "/", "") - - // execute - res := testutils.HTTPDo(t, req) - - // test - assert.Equal(t, res.StatusCode, http.StatusUnauthorized, "status code mismatch") - }) -} diff --git a/pkg/server/middleware/limit.go b/pkg/server/middleware/limit.go index 64d27d3e..3b3c3987 100644 --- a/pkg/server/middleware/limit.go +++ b/pkg/server/middleware/limit.go @@ -80,7 +80,7 @@ func cleanupVisitors() { mtx.Lock() for identifier, v := range visitors { - if time.Now().Sub(v.lastSeen) > 3*time.Minute { + if time.Since(v.lastSeen) > 3*time.Minute { delete(visitors, identifier) } } @@ -128,7 +128,7 @@ func Limit(next http.Handler) http.HandlerFunc { func ApplyLimit(h http.HandlerFunc, rateLimit bool) http.Handler { ret := h - if rateLimit && os.Getenv("GO_ENV") != "TEST" { + if rateLimit && os.Getenv("APP_ENV") != "TEST" { ret = Limit(ret) } diff --git a/pkg/server/middleware/main_test.go b/pkg/server/middleware/main_test.go deleted file mode 100644 index cd96508c..00000000 --- a/pkg/server/middleware/main_test.go +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors - * - * This file is part of Dnote. - * - * Dnote is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * Dnote is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with Dnote. If not, see . - */ - -package middleware - -import ( - "os" - "testing" - - "github.com/dnote/dnote/pkg/server/testutils" -) - -func TestMain(m *testing.M) { - testutils.InitTestDB() - - code := m.Run() - testutils.ClearData(testutils.DB) - - os.Exit(code) -} diff --git a/pkg/server/middleware/middleware.go b/pkg/server/middleware/middleware.go index bf422e94..582e8046 100644 --- a/pkg/server/middleware/middleware.go +++ b/pkg/server/middleware/middleware.go @@ -20,32 +20,13 @@ package middleware import ( "net/http" - "net/url" "github.com/dnote/dnote/pkg/server/app" - "github.com/gorilla/schema" ) // Middleware is a middleware for request handlers type Middleware func(h http.Handler, app *app.App, rateLimit bool) http.Handler -type payload struct { - Method string `schema:"_method"` -} - -func parseValues(values url.Values, dst interface{}) error { - dec := schema.NewDecoder() - - // Ignore CSRF token field - dec.IgnoreUnknownKeys(true) - - if err := dec.Decode(dst, values); err != nil { - return err - } - - return nil -} - // methodOverrideKey is the form key for overriding the method var methodOverrideKey = "_method" diff --git a/pkg/server/operations/main_test.go b/pkg/server/operations/main_test.go deleted file mode 100644 index 19a59dbb..00000000 --- a/pkg/server/operations/main_test.go +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors - * - * This file is part of Dnote. - * - * Dnote is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * Dnote is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with Dnote. If not, see . - */ - -package operations - -import ( - "os" - "testing" - - "github.com/dnote/dnote/pkg/server/testutils" -) - -func TestMain(m *testing.M) { - testutils.InitTestDB() - - code := m.Run() - testutils.ClearData(testutils.DB) - - os.Exit(code) -} diff --git a/pkg/server/operations/notes.go b/pkg/server/operations/notes.go index 88cac66d..ee6dd9af 100644 --- a/pkg/server/operations/notes.go +++ b/pkg/server/operations/notes.go @@ -19,13 +19,11 @@ package operations import ( - "errors" - "github.com/dnote/dnote/pkg/server/database" "github.com/dnote/dnote/pkg/server/helpers" "github.com/dnote/dnote/pkg/server/permissions" + "github.com/pkg/errors" "gorm.io/gorm" - pkgErrors "github.com/pkg/errors" ) // GetNote retrieves a note for the given user @@ -41,7 +39,7 @@ func GetNote(db *gorm.DB, uuid string, user *database.User) (database.Note, bool if errors.Is(err, gorm.ErrRecordNotFound) { return zeroNote, false, nil } else if err != nil { - return zeroNote, false, pkgErrors.Wrap(err, "finding note") + return zeroNote, false, errors.Wrap(err, "finding note") } if ok := permissions.ViewNote(user, note); !ok { diff --git a/pkg/server/operations/notes_test.go b/pkg/server/operations/notes_test.go index 020bbf75..6124b7e0 100644 --- a/pkg/server/operations/notes_test.go +++ b/pkg/server/operations/notes_test.go @@ -28,38 +28,41 @@ import ( ) func TestGetNote(t *testing.T) { - user := testutils.SetupUserData() - anotherUser := testutils.SetupUserData() + db := testutils.InitMemoryDB(t) - defer testutils.ClearData(testutils.DB) + user := testutils.SetupUserData(db) + anotherUser := testutils.SetupUserData(db) b1 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: "js", } - testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1") + testutils.MustExec(t, db.Save(&b1), "preparing b1") privateNote := database.Note{ + UUID: testutils.MustUUID(t), UserID: user.ID, BookUUID: b1.UUID, Body: "privateNote content", Deleted: false, Public: false, } - testutils.MustExec(t, testutils.DB.Save(&privateNote), "preparing privateNote") + testutils.MustExec(t, db.Save(&privateNote), "preparing privateNote") publicNote := database.Note{ + UUID: testutils.MustUUID(t), UserID: user.ID, BookUUID: b1.UUID, Body: "privateNote content", Deleted: false, Public: true, } - testutils.MustExec(t, testutils.DB.Save(&publicNote), "preparing privateNote") + testutils.MustExec(t, db.Save(&publicNote), "preparing privateNote") var privateNoteRecord, publicNoteRecord database.Note - testutils.MustExec(t, testutils.DB.Where("uuid = ?", privateNote.UUID).Preload("Book").Preload("User").First(&privateNoteRecord), "finding privateNote") - testutils.MustExec(t, testutils.DB.Where("uuid = ?", publicNote.UUID).Preload("Book").Preload("User").First(&publicNoteRecord), "finding publicNote") + testutils.MustExec(t, db.Where("uuid = ?", privateNote.UUID).Preload("Book").Preload("User").First(&privateNoteRecord), "finding privateNote") + testutils.MustExec(t, db.Where("uuid = ?", publicNote.UUID).Preload("Book").Preload("User").First(&publicNoteRecord), "finding publicNote") testCases := []struct { name string @@ -107,7 +110,7 @@ func TestGetNote(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - note, ok, err := GetNote(testutils.DB, tc.note.UUID, &tc.user) + note, ok, err := GetNote(db, tc.note.UUID, &tc.user) if err != nil { t.Fatal(errors.Wrap(err, "executing")) } @@ -119,29 +122,29 @@ func TestGetNote(t *testing.T) { } func TestGetNote_nonexistent(t *testing.T) { - user := testutils.SetupUserData() + db := testutils.InitMemoryDB(t) - defer testutils.ClearData(testutils.DB) + user := testutils.SetupUserData(db) b1 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: "js", } - testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1") + testutils.MustExec(t, db.Save(&b1), "preparing b1") - n1UUID := "4fd19336-671e-4ff3-8f22-662b80e22edc" n1 := database.Note{ - UUID: n1UUID, + UUID: "4fd19336-671e-4ff3-8f22-662b80e22edc", UserID: user.ID, BookUUID: b1.UUID, Body: "n1 content", Deleted: false, Public: false, } - testutils.MustExec(t, testutils.DB.Save(&n1), "preparing n1") + testutils.MustExec(t, db.Save(&n1), "preparing n1") nonexistentUUID := "4fd19336-671e-4ff3-8f22-662b80e22edd" - note, ok, err := GetNote(testutils.DB, nonexistentUUID, &user) + note, ok, err := GetNote(db, nonexistentUUID, &user) if err != nil { t.Fatal(errors.Wrap(err, "executing")) } diff --git a/pkg/server/permissions/permissions_test.go b/pkg/server/permissions/permissions_test.go index 607fb2a1..4054b66f 100644 --- a/pkg/server/permissions/permissions_test.go +++ b/pkg/server/permissions/permissions_test.go @@ -19,7 +19,6 @@ package permissions import ( - "os" "testing" "github.com/dnote/dnote/pkg/assert" @@ -27,44 +26,38 @@ import ( "github.com/dnote/dnote/pkg/server/testutils" ) -func TestMain(m *testing.M) { - testutils.InitTestDB() - - code := m.Run() - testutils.ClearData(testutils.DB) - - os.Exit(code) -} - func TestViewNote(t *testing.T) { - user := testutils.SetupUserData() - anotherUser := testutils.SetupUserData() + db := testutils.InitMemoryDB(t) - defer testutils.ClearData(testutils.DB) + user := testutils.SetupUserData(db) + anotherUser := testutils.SetupUserData(db) b1 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: "js", } - testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1") + testutils.MustExec(t, db.Save(&b1), "preparing b1") privateNote := database.Note{ + UUID: testutils.MustUUID(t), UserID: user.ID, BookUUID: b1.UUID, Body: "privateNote content", Deleted: false, Public: false, } - testutils.MustExec(t, testutils.DB.Save(&privateNote), "preparing privateNote") + testutils.MustExec(t, db.Save(&privateNote), "preparing privateNote") publicNote := database.Note{ + UUID: testutils.MustUUID(t), UserID: user.ID, BookUUID: b1.UUID, Body: "privateNote content", Deleted: false, Public: true, } - testutils.MustExec(t, testutils.DB.Save(&publicNote), "preparing privateNote") + testutils.MustExec(t, db.Save(&publicNote), "preparing privateNote") t.Run("owner accessing private note", func(t *testing.T) { result := ViewNote(&user, privateNote) diff --git a/pkg/server/presenters/book_test.go b/pkg/server/presenters/book_test.go new file mode 100644 index 00000000..98155769 --- /dev/null +++ b/pkg/server/presenters/book_test.go @@ -0,0 +1,217 @@ +/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors + * + * This file is part of Dnote. + * + * Dnote is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Dnote is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Dnote. If not, see . + */ + +package presenters + +import ( + "testing" + "time" + + "github.com/dnote/dnote/pkg/assert" + "github.com/dnote/dnote/pkg/server/database" +) + +func TestPresentBook(t *testing.T) { + createdAt := time.Date(2025, 1, 15, 10, 30, 45, 123456789, time.UTC) + updatedAt := time.Date(2025, 2, 20, 14, 45, 30, 987654321, time.UTC) + + testCases := []struct { + name string + input database.Book + expected Book + }{ + { + name: "basic book", + input: database.Book{ + Model: database.Model{ + ID: 1, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, + UUID: "a1b2c3d4-e5f6-4789-a012-3456789abcde", + UserID: 42, + Label: "JavaScript", + USN: 100, + }, + expected: Book{ + UUID: "a1b2c3d4-e5f6-4789-a012-3456789abcde", + USN: 100, + CreatedAt: FormatTS(createdAt), + UpdatedAt: FormatTS(updatedAt), + Label: "JavaScript", + }, + }, + { + name: "book with special characters in label", + input: database.Book{ + Model: database.Model{ + ID: 2, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, + UUID: "f1e2d3c4-b5a6-4987-b654-321fedcba098", + UserID: 99, + Label: "C++", + USN: 200, + }, + expected: Book{ + UUID: "f1e2d3c4-b5a6-4987-b654-321fedcba098", + USN: 200, + CreatedAt: FormatTS(createdAt), + UpdatedAt: FormatTS(updatedAt), + Label: "C++", + }, + }, + { + name: "book with empty label", + input: database.Book{ + Model: database.Model{ + ID: 3, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, + UUID: "12345678-90ab-4cde-8901-234567890abc", + UserID: 1, + Label: "", + USN: 0, + }, + expected: Book{ + UUID: "12345678-90ab-4cde-8901-234567890abc", + USN: 0, + CreatedAt: FormatTS(createdAt), + UpdatedAt: FormatTS(updatedAt), + Label: "", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := PresentBook(tc.input) + + assert.Equal(t, got.UUID, tc.expected.UUID, "UUID mismatch") + assert.Equal(t, got.USN, tc.expected.USN, "USN mismatch") + assert.Equal(t, got.Label, tc.expected.Label, "Label mismatch") + assert.Equal(t, got.CreatedAt, tc.expected.CreatedAt, "CreatedAt mismatch") + assert.Equal(t, got.UpdatedAt, tc.expected.UpdatedAt, "UpdatedAt mismatch") + }) + } +} + +func TestPresentBooks(t *testing.T) { + createdAt1 := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt1 := time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC) + createdAt2 := time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC) + updatedAt2 := time.Date(2025, 2, 2, 0, 0, 0, 0, time.UTC) + + testCases := []struct { + name string + input []database.Book + expected []Book + }{ + { + name: "empty slice", + input: []database.Book{}, + expected: []Book{}, + }, + { + name: "single book", + input: []database.Book{ + { + Model: database.Model{ + ID: 1, + CreatedAt: createdAt1, + UpdatedAt: updatedAt1, + }, + UUID: "9a8b7c6d-5e4f-4321-9876-543210fedcba", + UserID: 1, + Label: "Go", + USN: 10, + }, + }, + expected: []Book{ + { + UUID: "9a8b7c6d-5e4f-4321-9876-543210fedcba", + USN: 10, + CreatedAt: FormatTS(createdAt1), + UpdatedAt: FormatTS(updatedAt1), + Label: "Go", + }, + }, + }, + { + name: "multiple books", + input: []database.Book{ + { + Model: database.Model{ + ID: 1, + CreatedAt: createdAt1, + UpdatedAt: updatedAt1, + }, + UUID: "9a8b7c6d-5e4f-4321-9876-543210fedcba", + UserID: 1, + Label: "Go", + USN: 10, + }, + { + Model: database.Model{ + ID: 2, + CreatedAt: createdAt2, + UpdatedAt: updatedAt2, + }, + UUID: "abcdef01-2345-4678-9abc-def012345678", + UserID: 1, + Label: "Python", + USN: 20, + }, + }, + expected: []Book{ + { + UUID: "9a8b7c6d-5e4f-4321-9876-543210fedcba", + USN: 10, + CreatedAt: FormatTS(createdAt1), + UpdatedAt: FormatTS(updatedAt1), + Label: "Go", + }, + { + UUID: "abcdef01-2345-4678-9abc-def012345678", + USN: 20, + CreatedAt: FormatTS(createdAt2), + UpdatedAt: FormatTS(updatedAt2), + Label: "Python", + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := PresentBooks(tc.input) + + assert.Equal(t, len(got), len(tc.expected), "Length mismatch") + + for i := range got { + assert.Equal(t, got[i].UUID, tc.expected[i].UUID, "UUID mismatch") + assert.Equal(t, got[i].USN, tc.expected[i].USN, "USN mismatch") + assert.Equal(t, got[i].Label, tc.expected[i].Label, "Label mismatch") + assert.Equal(t, got[i].CreatedAt, tc.expected[i].CreatedAt, "CreatedAt mismatch") + assert.Equal(t, got[i].UpdatedAt, tc.expected[i].UpdatedAt, "UpdatedAt mismatch") + } + }) + } +} diff --git a/pkg/server/presenters/email_preference.go b/pkg/server/presenters/email_preference.go deleted file mode 100644 index acf52eed..00000000 --- a/pkg/server/presenters/email_preference.go +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors - * - * This file is part of Dnote. - * - * Dnote is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * Dnote is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with Dnote. If not, see . - */ - -package presenters - -import ( - "time" - - "github.com/dnote/dnote/pkg/server/database" -) - -// EmailPreference is a presented email digest -type EmailPreference struct { - InactiveReminder bool `json:"inactive_reminder"` - ProductUpdate bool `json:"product_update"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` -} - -// PresentEmailPreference presents a digest -func PresentEmailPreference(p database.EmailPreference) EmailPreference { - ret := EmailPreference{ - InactiveReminder: p.InactiveReminder, - ProductUpdate: p.ProductUpdate, - CreatedAt: FormatTS(p.CreatedAt), - UpdatedAt: FormatTS(p.UpdatedAt), - } - - return ret -} diff --git a/pkg/server/tmpl/main_test.go b/pkg/server/presenters/helpers_test.go similarity index 71% rename from pkg/server/tmpl/main_test.go rename to pkg/server/presenters/helpers_test.go index ceabf5cc..c48c90ea 100644 --- a/pkg/server/tmpl/main_test.go +++ b/pkg/server/presenters/helpers_test.go @@ -16,20 +16,20 @@ * along with Dnote. If not, see . */ -package tmpl +package presenters import ( - "os" "testing" + "time" - "github.com/dnote/dnote/pkg/server/testutils" + "github.com/dnote/dnote/pkg/assert" ) -func TestMain(m *testing.M) { - testutils.InitTestDB() +func TestFormatTS(t *testing.T) { + input := time.Date(2025, 1, 15, 10, 30, 45, 123456789, time.UTC) + expected := time.Date(2025, 1, 15, 10, 30, 45, 123457000, time.UTC) - code := m.Run() - testutils.ClearData(testutils.DB) + got := FormatTS(input) - os.Exit(code) + assert.Equal(t, got, expected, "FormatTS mismatch") } diff --git a/pkg/server/presenters/note_test.go b/pkg/server/presenters/note_test.go new file mode 100644 index 00000000..822c5cea --- /dev/null +++ b/pkg/server/presenters/note_test.go @@ -0,0 +1,127 @@ +/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors + * + * This file is part of Dnote. + * + * Dnote is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * Dnote is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with Dnote. If not, see . + */ + +package presenters + +import ( + "testing" + "time" + + "github.com/dnote/dnote/pkg/assert" + "github.com/dnote/dnote/pkg/server/database" +) + +func TestPresentNote(t *testing.T) { + createdAt := time.Date(2025, 1, 15, 10, 30, 45, 123456789, time.UTC) + updatedAt := time.Date(2025, 2, 20, 14, 45, 30, 987654321, time.UTC) + + input := database.Note{ + Model: database.Model{ + ID: 1, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, + UUID: "a1b2c3d4-e5f6-4789-a012-3456789abcde", + UserID: 42, + BookUUID: "f1e2d3c4-b5a6-4987-b654-321fedcba098", + Body: "Test note content", + AddedOn: 1234567890, + Public: true, + USN: 100, + Book: database.Book{ + UUID: "f1e2d3c4-b5a6-4987-b654-321fedcba098", + Label: "JavaScript", + }, + User: database.User{ + UUID: "9a8b7c6d-5e4f-4321-9876-543210fedcba", + }, + } + + got := PresentNote(input) + + assert.Equal(t, got.UUID, "a1b2c3d4-e5f6-4789-a012-3456789abcde", "UUID mismatch") + assert.Equal(t, got.Body, "Test note content", "Body mismatch") + assert.Equal(t, got.AddedOn, int64(1234567890), "AddedOn mismatch") + assert.Equal(t, got.Public, true, "Public mismatch") + assert.Equal(t, got.USN, 100, "USN mismatch") + assert.Equal(t, got.CreatedAt, FormatTS(createdAt), "CreatedAt mismatch") + assert.Equal(t, got.UpdatedAt, FormatTS(updatedAt), "UpdatedAt mismatch") + assert.Equal(t, got.Book.UUID, "f1e2d3c4-b5a6-4987-b654-321fedcba098", "Book UUID mismatch") + assert.Equal(t, got.Book.Label, "JavaScript", "Book Label mismatch") + assert.Equal(t, got.User.UUID, "9a8b7c6d-5e4f-4321-9876-543210fedcba", "User UUID mismatch") +} + +func TestPresentNotes(t *testing.T) { + createdAt1 := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt1 := time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC) + createdAt2 := time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC) + updatedAt2 := time.Date(2025, 2, 2, 0, 0, 0, 0, time.UTC) + + input := []database.Note{ + { + Model: database.Model{ + ID: 1, + CreatedAt: createdAt1, + UpdatedAt: updatedAt1, + }, + UUID: "a1b2c3d4-e5f6-4789-a012-3456789abcde", + UserID: 1, + BookUUID: "f1e2d3c4-b5a6-4987-b654-321fedcba098", + Body: "First note", + AddedOn: 1000000000, + Public: false, + USN: 10, + Book: database.Book{ + UUID: "f1e2d3c4-b5a6-4987-b654-321fedcba098", + Label: "Go", + }, + User: database.User{ + UUID: "9a8b7c6d-5e4f-4321-9876-543210fedcba", + }, + }, + { + Model: database.Model{ + ID: 2, + CreatedAt: createdAt2, + UpdatedAt: updatedAt2, + }, + UUID: "12345678-90ab-4cde-8901-234567890abc", + UserID: 1, + BookUUID: "abcdef01-2345-4678-9abc-def012345678", + Body: "Second note", + AddedOn: 2000000000, + Public: true, + USN: 20, + Book: database.Book{ + UUID: "abcdef01-2345-4678-9abc-def012345678", + Label: "Python", + }, + User: database.User{ + UUID: "9a8b7c6d-5e4f-4321-9876-543210fedcba", + }, + }, + } + + got := PresentNotes(input) + + assert.Equal(t, len(got), 2, "Length mismatch") + assert.Equal(t, got[0].UUID, "a1b2c3d4-e5f6-4789-a012-3456789abcde", "Note 0 UUID mismatch") + assert.Equal(t, got[0].Body, "First note", "Note 0 Body mismatch") + assert.Equal(t, got[1].UUID, "12345678-90ab-4cde-8901-234567890abc", "Note 1 UUID mismatch") + assert.Equal(t, got[1].Body, "Second note", "Note 1 Body mismatch") +} diff --git a/pkg/server/session/session.go b/pkg/server/session/session.go index 5ad92fec..8c55549c 100644 --- a/pkg/server/session/session.go +++ b/pkg/server/session/session.go @@ -27,14 +27,12 @@ type Session struct { UUID string `json:"uuid"` Email string `json:"email"` EmailVerified bool `json:"email_verified"` - Pro bool `json:"pro"` } // New returns a new session for the given user func New(user database.User, account database.Account) Session { return Session{ UUID: user.UUID, - Pro: user.Cloud, Email: account.Email.String, EmailVerified: account.EmailVerified, } diff --git a/pkg/server/session/session_test.go b/pkg/server/session/session_test.go index 967053e1..107dacfe 100644 --- a/pkg/server/session/session_test.go +++ b/pkg/server/session/session_test.go @@ -27,36 +27,32 @@ import ( ) func TestNew(t *testing.T) { - u1 := database.User{UUID: "0f5f0054-d23f-4be1-b5fb-57673109e9cb", Cloud: true} + u1 := database.User{UUID: "0f5f0054-d23f-4be1-b5fb-57673109e9cb"} a1 := database.Account{Email: database.ToNullString("alice@example.com"), EmailVerified: false} - u2 := database.User{UUID: "718a1041-bbe6-496e-bbe4-ea7e572c295e", Cloud: false} + u2 := database.User{UUID: "718a1041-bbe6-496e-bbe4-ea7e572c295e"} a2 := database.Account{Email: database.ToNullString("bob@example.com"), EmailVerified: false} testCases := []struct { - user database.User - account database.Account - expectedPro bool + user database.User + account database.Account }{ { - user: u1, - account: a1, - expectedPro: true, + user: u1, + account: a1, }, { - user: u2, - account: a2, - expectedPro: false, + user: u2, + account: a2, }, } - for _, tc := range testCases { - t.Run(fmt.Sprintf("user pro %t", tc.expectedPro), func(t *testing.T) { + for idx, tc := range testCases { + t.Run(fmt.Sprintf("user %d", idx), func(t *testing.T) { // Execute got := New(tc.user, tc.account) expected := Session{ UUID: tc.user.UUID, - Pro: tc.expectedPro, Email: tc.account.Email.String, EmailVerified: tc.account.EmailVerified, } diff --git a/pkg/server/testutils/main.go b/pkg/server/testutils/main.go index 33a4a00d..08bc07de 100644 --- a/pkg/server/testutils/main.go +++ b/pkg/server/testutils/main.go @@ -27,36 +27,44 @@ import ( "net/http" "net/url" "reflect" - // "strconv" "strings" "sync" "testing" "time" - "github.com/dnote/dnote/pkg/server/config" "github.com/dnote/dnote/pkg/server/database" - "gorm.io/gorm" + "github.com/dnote/dnote/pkg/server/helpers" "github.com/pkg/errors" "golang.org/x/crypto/bcrypt" + "gorm.io/driver/sqlite" + "gorm.io/gorm" ) -func init() { - rand.Seed(time.Now().UnixNano()) +// InitDB opens a database at the given path and initializes the schema +func InitDB(dbPath string) *gorm.DB { + db := database.Open(dbPath) + database.InitSchema(db) + database.Migrate(db) + return db } -// DB is the database connection to a test database -var DB *gorm.DB - -// InitTestDB establishes connection pool with the test database specified by -// the environment variable configuration and initalizes a new schema -func InitTestDB() { - c := config.Load() - fmt.Println(c.DB.GetConnectionStr()) - db := database.Open(c) +// InitMemoryDB creates an in-memory SQLite database with the schema initialized +func InitMemoryDB(t *testing.T) *gorm.DB { + // Use file-based in-memory database with unique UUID per test to avoid sharing + uuid, err := helpers.GenUUID() + if err != nil { + t.Fatalf("failed to generate UUID for test database: %v", err) + } + dbName := fmt.Sprintf("file:%s?mode=memory&cache=shared", uuid) + db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{}) + if err != nil { + t.Fatalf("failed to open in-memory database: %v", err) + } database.InitSchema(db) + database.Migrate(db) - DB = db + return db } // ClearData deletes all records from the database @@ -68,15 +76,9 @@ func ClearData(db *gorm.DB) { if err := db.Where("1 = 1").Delete(&database.Book{}).Error; err != nil { panic(errors.Wrap(err, "Failed to clear books")) } - if err := db.Where("1 = 1").Delete(&database.Notification{}).Error; err != nil { - panic(errors.Wrap(err, "Failed to clear notifications")) - } if err := db.Where("1 = 1").Delete(&database.Token{}).Error; err != nil { panic(errors.Wrap(err, "Failed to clear tokens")) } - if err := db.Where("1 = 1").Delete(&database.EmailPreference{}).Error; err != nil { - panic(errors.Wrap(err, "Failed to clear email preferences")) - } if err := db.Where("1 = 1").Delete(&database.Session{}).Error; err != nil { panic(errors.Wrap(err, "Failed to clear sessions")) } @@ -88,13 +90,27 @@ func ClearData(db *gorm.DB) { } } +// MustUUID generates a UUID and fails the test on error +func MustUUID(t *testing.T) string { + uuid, err := helpers.GenUUID() + if err != nil { + t.Fatal(errors.Wrap(err, "Failed to generate UUID")) + } + return uuid +} + // SetupUserData creates and returns a new user for testing purposes -func SetupUserData() database.User { - user := database.User{ - Cloud: true, +func SetupUserData(db *gorm.DB) database.User { + uuid, err := helpers.GenUUID() + if err != nil { + panic(errors.Wrap(err, "Failed to generate UUID")) } - if err := DB.Save(&user).Error; err != nil { + user := database.User{ + UUID: uuid, + } + + if err := db.Save(&user).Error; err != nil { panic(errors.Wrap(err, "Failed to prepare user")) } @@ -102,7 +118,7 @@ func SetupUserData() database.User { } // SetupAccountData creates and returns a new account for the user -func SetupAccountData(user database.User, email, password string) database.Account { +func SetupAccountData(db *gorm.DB, user database.User, email, password string) database.Account { account := database.Account{ UserID: user.ID, } @@ -116,7 +132,7 @@ func SetupAccountData(user database.User, email, password string) database.Accou } account.Password = database.ToNullString(string(hashedPassword)) - if err := DB.Save(&account).Error; err != nil { + if err := db.Save(&account).Error; err != nil { panic(errors.Wrap(err, "Failed to prepare account")) } @@ -124,33 +140,19 @@ func SetupAccountData(user database.User, email, password string) database.Accou } // SetupSession creates and returns a new user session -func SetupSession(t *testing.T, user database.User) database.Session { +func SetupSession(db *gorm.DB, user database.User) database.Session { session := database.Session{ Key: "Vvgm3eBXfXGEFWERI7faiRJ3DAzJw+7DdT9J1LEyNfI=", UserID: user.ID, ExpiresAt: time.Now().Add(time.Hour * 24), } - if err := DB.Save(&session).Error; err != nil { - t.Fatal(errors.Wrap(err, "Failed to prepare user")) + if err := db.Save(&session).Error; err != nil { + panic(errors.Wrap(err, "Failed to prepare user")) } return session } -// SetupEmailPreferenceData creates and returns a new email frequency for a user -func SetupEmailPreferenceData(user database.User, inactiveReminder bool) database.EmailPreference { - frequency := database.EmailPreference{ - UserID: user.ID, - InactiveReminder: inactiveReminder, - } - - if err := DB.Save(&frequency).Error; err != nil { - panic(errors.Wrap(err, "Failed to prepare email frequency")) - } - - return frequency -} - // HTTPDo makes an HTTP request and returns a response func HTTPDo(t *testing.T, req *http.Request) *http.Response { hc := http.Client{ @@ -170,8 +172,8 @@ func HTTPDo(t *testing.T, req *http.Request) *http.Response { return res } -// SetReqAuthHeader sets the authorization header in the given request for the given user -func SetReqAuthHeader(t *testing.T, req *http.Request, user database.User) { +// SetReqAuthHeader sets the authorization header in the given request for the given user with a specific DB +func SetReqAuthHeader(t *testing.T, db *gorm.DB, req *http.Request, user database.User) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { t.Fatal(errors.Wrap(err, "reading random bits")) @@ -182,19 +184,18 @@ func SetReqAuthHeader(t *testing.T, req *http.Request, user database.User) { UserID: user.ID, ExpiresAt: time.Now().Add(time.Hour * 10 * 24), } - if err := DB.Save(&session).Error; err != nil { + if err := db.Save(&session).Error; err != nil { t.Fatal(errors.Wrap(err, "Failed to prepare user")) } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", session.Key)) } -// HTTPAuthDo makes an HTTP request with an appropriate authorization header for a user -func HTTPAuthDo(t *testing.T, req *http.Request, user database.User) *http.Response { - SetReqAuthHeader(t, req, user) +// HTTPAuthDo makes an HTTP request with an appropriate authorization header for a user with a specific DB +func HTTPAuthDo(t *testing.T, db *gorm.DB, req *http.Request, user database.User) *http.Response { + SetReqAuthHeader(t, db, req, user) return HTTPDo(t, req) - } // MakeReq makes an HTTP request and returns a response diff --git a/pkg/server/tmpl/app_test.go b/pkg/server/tmpl/app_test.go index f7042574..fba9bed4 100644 --- a/pkg/server/tmpl/app_test.go +++ b/pkg/server/tmpl/app_test.go @@ -31,7 +31,9 @@ import ( func TestAppShellExecute(t *testing.T) { t.Run("home", func(t *testing.T) { - a, err := NewAppShell(testutils.DB, []byte("{{ .Title }}{{ .MetaTags }}")) + db := testutils.InitMemoryDB(t) + + a, err := NewAppShell(db, []byte("{{ .Title }}{{ .MetaTags }}")) if err != nil { t.Fatal(errors.Wrap(err, "preparing app shell")) } @@ -50,23 +52,25 @@ func TestAppShellExecute(t *testing.T) { }) t.Run("note", func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) - user := testutils.SetupUserData() + user := testutils.SetupUserData(db) b1 := database.Book{ + UUID: testutils.MustUUID(t), UserID: user.ID, Label: "js", } - testutils.MustExec(t, testutils.DB.Save(&b1), "preparing b1") + testutils.MustExec(t, db.Save(&b1), "preparing b1") n1 := database.Note{ + UUID: testutils.MustUUID(t), UserID: user.ID, BookUUID: b1.UUID, Public: true, Body: "n1 content", } - testutils.MustExec(t, testutils.DB.Save(&n1), "preparing note") + testutils.MustExec(t, db.Save(&n1), "preparing note") - a, err := NewAppShell(testutils.DB, []byte("{{ .MetaTags }}")) + a, err := NewAppShell(db, []byte("{{ .MetaTags }}")) if err != nil { t.Fatal(errors.Wrap(err, "preparing app shell")) } diff --git a/pkg/server/tmpl/data_test.go b/pkg/server/tmpl/data_test.go index 8c60e3d3..c072d12e 100644 --- a/pkg/server/tmpl/data_test.go +++ b/pkg/server/tmpl/data_test.go @@ -42,7 +42,8 @@ func TestNotePageGetData(t *testing.T) { // Set time.Local to UTC for deterministic test time.Local = time.UTC - a, err := NewAppShell(testutils.DB, nil) + db := testutils.InitMemoryDB(t) + a, err := NewAppShell(db, nil) if err != nil { t.Fatal(errors.Wrap(err, "preparing app shell")) } diff --git a/pkg/server/token/main_test.go b/pkg/server/token/main_test.go deleted file mode 100644 index 84e4f0b2..00000000 --- a/pkg/server/token/main_test.go +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors - * - * This file is part of Dnote. - * - * Dnote is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * Dnote is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with Dnote. If not, see . - */ - -package token - -import ( - "os" - "testing" - - "github.com/dnote/dnote/pkg/server/testutils" -) - -func TestMain(m *testing.M) { - testutils.InitTestDB() - - code := m.Run() - testutils.ClearData(testutils.DB) - - os.Exit(code) -} diff --git a/pkg/server/token/token_test.go b/pkg/server/token/token_test.go index 371ab71c..922cc93d 100644 --- a/pkg/server/token/token_test.go +++ b/pkg/server/token/token_test.go @@ -33,30 +33,30 @@ func TestCreate(t *testing.T) { kind string }{ { - kind: database.TokenTypeEmailPreference, + kind: database.TokenTypeEmailVerification, }, } for _, tc := range testCases { t.Run(fmt.Sprintf("token type %s", tc.kind), func(t *testing.T) { - defer testutils.ClearData(testutils.DB) + db := testutils.InitMemoryDB(t) // Set up - u := testutils.SetupUserData() + u := testutils.SetupUserData(db) // Execute - tok, err := Create(testutils.DB, u.ID, tc.kind) + tok, err := Create(db, u.ID, tc.kind) if err != nil { t.Fatal(errors.Wrap(err, "performing")) } // Test var count int64 - testutils.MustExec(t, testutils.DB.Model(&database.Token{}).Count(&count), "counting token") + testutils.MustExec(t, db.Model(&database.Token{}).Count(&count), "counting token") assert.Equalf(t, count, int64(1), "error mismatch") var tokenRecord database.Token - testutils.MustExec(t, testutils.DB.First(&tokenRecord), "finding token") + testutils.MustExec(t, db.First(&tokenRecord), "finding token") assert.Equalf(t, tokenRecord.UserID, tok.UserID, "UserID mismatch") assert.Equalf(t, tokenRecord.Value, tok.Value, "Value mismatch") assert.Equalf(t, tokenRecord.Type, tok.Type, "Type mismatch") diff --git a/pkg/server/views/helpers.go b/pkg/server/views/helpers.go index 867033af..ac4546f7 100644 --- a/pkg/server/views/helpers.go +++ b/pkg/server/views/helpers.go @@ -50,7 +50,7 @@ func initHelpers(c Config, a *app.App) template.FuncMap { "defaultValue": ctx.defaultValue, "add": ctx.add, "assetBaseURL": func() string { - return a.Config.AssetBaseURL + return a.AssetBaseURL }, } diff --git a/pkg/server/views/templates/users/settings_about.gohtml b/pkg/server/views/templates/users/settings_about.gohtml index bba9a8f0..3252b9de 100644 --- a/pkg/server/views/templates/users/settings_about.gohtml +++ b/pkg/server/views/templates/users/settings_about.gohtml @@ -27,27 +27,6 @@ - {{if ne .Standalone "true"}} -
-
-
-

Support

-
- -
- {{if .User.Cloud}} - - support@getdnote.com - - {{else}} - Not eligible - {{end}} -
-
-
- {{else}} - - {{end}} diff --git a/pkg/server/views/view.go b/pkg/server/views/view.go index 75a9aa92..2b0e57a9 100644 --- a/pkg/server/views/view.go +++ b/pkg/server/views/view.go @@ -119,9 +119,6 @@ func (v *View) Render(w http.ResponseWriter, r *http.Request, data *Data, status vd.Yield["EmailVerified"] = vd.Account.EmailVerified vd.Yield["EmailVerified"] = vd.Account.EmailVerified } - if vd.User != nil { - vd.Yield["Cloud"] = vd.User.Cloud - } vd.Yield["CurrentPath"] = r.URL.Path vd.Yield["Standalone"] = buildinfo.Standalone diff --git a/scripts/server/dev.sh b/scripts/server/dev.sh index 11f1e207..5eb93328 100755 --- a/scripts/server/dev.sh +++ b/scripts/server/dev.sh @@ -23,7 +23,7 @@ cp "$basePath"/pkg/server/assets/static/* "$basePath/pkg/server/static" # run server moduleName="github.com/dnote/dnote" ldflags="-X '$moduleName/pkg/server/buildinfo.CSSFiles=main.css' -X '$moduleName/pkg/server/buildinfo.JSFiles=main.js' -X '$moduleName/pkg/server/buildinfo.Version=dev' -X '$moduleName/pkg/server/buildinfo.Standalone=true'" -task="go run -ldflags \"$ldflags\" main.go start -port 3000" +task="go run -ldflags \"$ldflags\" --tags fts5 main.go start -port 3000" ( cd "$basePath/pkg/watcher" && \ diff --git a/scripts/server/test.sh b/scripts/server/test.sh index ed1d0a41..10207010 100755 --- a/scripts/server/test.sh +++ b/scripts/server/test.sh @@ -8,9 +8,9 @@ pushd "$dir/../../pkg/server" function run_test { if [ -z "$1" ]; then - go test ./... -cover -p 1 + go test -tags "fts5" ./... -cover else - go test -run "$1" -cover -p 1 + go test -tags "fts5" -run "$1" -cover fi }