/* Copyright (C) 2019, 2020 Monomax Software Pty Ltd * * This file is part of Dnote. * * Dnote is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * Dnote is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with Dnote. If not, see . */ package api import ( "encoding/json" "fmt" "net/http" "testing" "time" "github.com/dnote/dnote/pkg/assert" "github.com/dnote/dnote/pkg/clock" "github.com/dnote/dnote/pkg/server/app" "github.com/dnote/dnote/pkg/server/config" "github.com/dnote/dnote/pkg/server/models" "github.com/dnote/dnote/pkg/server/testutils" "github.com/pkg/errors" "golang.org/x/crypto/bcrypt" ) func assertSessionResp(t *testing.T, res *http.Response) { // after register, should sign in user var got SessionResponse if err := json.NewDecoder(res.Body).Decode(&got); err != nil { t.Fatal(errors.Wrap(err, "decoding payload")) } var sessionCount int var session models.Session models.MustExec(t, models.TestDB.Model(&models.Session{}).Count(&sessionCount), "counting session") models.MustExec(t, models.TestDB.First(&session), "getting session") assert.Equal(t, sessionCount, 1, "sessionCount mismatch") assert.Equal(t, got.Key, session.Key, "session Key mismatch") assert.Equal(t, got.ExpiresAt, session.ExpiresAt.Unix(), "session ExpiresAt mismatch") c := testutils.GetCookieByName(res.Cookies(), "id") assert.Equal(t, c.Value, session.Key, "session key mismatch") assert.Equal(t, c.Path, "/", "session path mismatch") assert.Equal(t, c.HttpOnly, true, "session HTTPOnly mismatch") assert.Equal(t, c.Expires.Unix(), session.ExpiresAt.Unix(), "session Expires mismatch") } func TestRegister(t *testing.T) { testCases := []struct { email string password string onPremise bool expectedPro bool }{ { email: "alice@example.com", password: "pass1234", onPremise: false, expectedPro: false, }, { email: "bob@example.com", password: "Y9EwmjH@Jq6y5a64MSACUoM4w7SAhzvY", onPremise: false, expectedPro: false, }, { email: "chuck@example.com", password: "e*H@kJi^vXbWEcD9T5^Am!Y@7#Po2@PC", onPremise: false, expectedPro: false, }, // on premise { email: "dan@example.com", password: "e*H@kJi^vXbWEcD9T5^Am!Y@7#Po2@PC", onPremise: true, expectedPro: true, }, } for _, tc := range testCases { t.Run(fmt.Sprintf("register %s %s", tc.email, tc.password), func(t *testing.T) { defer models.ClearTestData(models.TestDB) c := config.Load() c.SetOnPremise(tc.onPremise) // Setup emailBackend := testutils.MockEmailbackendImplementation{} server := MustNewServer(t, &app.App{ Clock: clock.NewMock(), EmailBackend: &emailBackend, Config: c, }) defer server.Close() dat := fmt.Sprintf(`{"email": "%s", "password": "%s"}`, tc.email, tc.password) req := testutils.MakeReq(server.URL, "POST", "/v3/register", dat) // Execute res := testutils.HTTPDo(t, req) // Test assert.StatusCodeEquals(t, res, http.StatusCreated, "") var account models.Account models.MustExec(t, models.TestDB.Where("email = ?", tc.email).First(&account), "finding account") assert.Equal(t, account.Email.String, tc.email, "Email mismatch") assert.NotEqual(t, account.UserID, 0, "UserID mismatch") passwordErr := bcrypt.CompareHashAndPassword([]byte(account.Password.String), []byte(tc.password)) assert.Equal(t, passwordErr, nil, "Password mismatch") var user models.User models.MustExec(t, models.TestDB.Where("id = ?", account.UserID).First(&user), "finding user") assert.Equal(t, user.Cloud, tc.expectedPro, "Cloud mismatch") assert.Equal(t, user.MaxUSN, 0, "MaxUSN mismatch") // welcome email assert.Equalf(t, len(emailBackend.Emails), 1, "email queue count mismatch") assert.DeepEqual(t, emailBackend.Emails[0].To, []string{tc.email}, "email to mismatch") // after register, should sign in user assertSessionResp(t, res) }) } } func TestRegisterMissingParams(t *testing.T) { t.Run("missing email", func(t *testing.T) { defer models.ClearTestData(models.TestDB) // Setup server := MustNewServer(t, &app.App{ Clock: clock.NewMock(), }) defer server.Close() dat := fmt.Sprintf(`{"password": %s}`, "SLMZFM5RmSjA5vfXnG5lPOnrpZSbtmV76cnAcrlr2yU") req := testutils.MakeReq(server.URL, "POST", "/v3/register", dat) // Execute res := testutils.HTTPDo(t, req) // Test assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch") var accountCount, userCount int models.MustExec(t, models.TestDB.Model(&models.Account{}).Count(&accountCount), "counting account") models.MustExec(t, models.TestDB.Model(&models.User{}).Count(&userCount), "counting user") assert.Equal(t, accountCount, 0, "accountCount mismatch") assert.Equal(t, userCount, 0, "userCount mismatch") }) t.Run("missing password", func(t *testing.T) { defer models.ClearTestData(models.TestDB) // Setup server := MustNewServer(t, &app.App{ Clock: clock.NewMock(), }) defer server.Close() dat := fmt.Sprintf(`{"email": "%s"}`, "alice@example.com") req := testutils.MakeReq(server.URL, "POST", "/v3/register", dat) // Execute res := testutils.HTTPDo(t, req) // Test assert.StatusCodeEquals(t, res, http.StatusBadRequest, "Status mismatch") var accountCount, userCount int models.MustExec(t, models.TestDB.Model(&models.Account{}).Count(&accountCount), "counting account") models.MustExec(t, models.TestDB.Model(&models.User{}).Count(&userCount), "counting user") assert.Equal(t, accountCount, 0, "accountCount mismatch") assert.Equal(t, userCount, 0, "userCount mismatch") }) } func TestRegisterDuplicateEmail(t *testing.T) { defer models.ClearTestData(models.TestDB) // Setup server := MustNewServer(t, &app.App{ Clock: clock.NewMock(), }) defer server.Close() u := models.SetUpUserData() models.SetUpAccountData(u, "alice@example.com", "somepassword") dat := `{"email": "alice@example.com", "password": "foobarbaz"}` req := testutils.MakeReq(server.URL, "POST", "/v3/register", dat) // Execute res := testutils.HTTPDo(t, req) // Test assert.StatusCodeEquals(t, res, http.StatusBadRequest, "status code mismatch") var accountCount, userCount, verificationTokenCount int models.MustExec(t, models.TestDB.Model(&models.Account{}).Count(&accountCount), "counting account") models.MustExec(t, models.TestDB.Model(&models.User{}).Count(&userCount), "counting user") models.MustExec(t, models.TestDB.Model(&models.Token{}).Count(&verificationTokenCount), "counting verification token") var user models.User models.MustExec(t, models.TestDB.Where("id = ?", u.ID).First(&user), "finding user") assert.Equal(t, accountCount, 1, "account count mismatch") assert.Equal(t, userCount, 1, "user count mismatch") assert.Equal(t, verificationTokenCount, 0, "verification_token should not have been created") assert.Equal(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch") } func TestRegisterDisabled(t *testing.T) { defer models.ClearTestData(models.TestDB) c := config.Load() c.DisableRegistration = true // Setup server := MustNewServer(t, &app.App{ Clock: clock.NewMock(), Config: c, }) defer server.Close() dat := `{"email": "alice@example.com", "password": "foobarbaz"}` req := testutils.MakeReq(server.URL, "POST", "/v3/register", dat) // Execute res := testutils.HTTPDo(t, req) // Test assert.StatusCodeEquals(t, res, http.StatusForbidden, "status code mismatch") var accountCount, userCount int models.MustExec(t, models.TestDB.Model(&models.Account{}).Count(&accountCount), "counting account") models.MustExec(t, models.TestDB.Model(&models.User{}).Count(&userCount), "counting user") assert.Equal(t, accountCount, 0, "account count mismatch") assert.Equal(t, userCount, 0, "user count mismatch") } func TestSignIn(t *testing.T) { t.Run("success", func(t *testing.T) { defer models.ClearTestData(models.TestDB) // Setup server := MustNewServer(t, &app.App{ Clock: clock.NewMock(), }) defer server.Close() u := models.SetUpUserData() models.SetUpAccountData(u, "alice@example.com", "pass1234") dat := `{"email": "alice@example.com", "password": "pass1234"}` req := testutils.MakeReq(server.URL, "POST", "/v3/signin", dat) // Execute res := testutils.HTTPDo(t, req) // Test assert.StatusCodeEquals(t, res, http.StatusOK, "") var user models.User models.MustExec(t, models.TestDB.Model(&models.User{}).First(&user), "finding user") assert.NotEqual(t, user.LastLoginAt, nil, "LastLoginAt mismatch") // after register, should sign in user assertSessionResp(t, res) }) t.Run("wrong password", func(t *testing.T) { defer models.ClearTestData(models.TestDB) // Setup server := MustNewServer(t, &app.App{ Clock: clock.NewMock(), }) defer server.Close() u := models.SetUpUserData() models.SetUpAccountData(u, "alice@example.com", "pass1234") dat := `{"email": "alice@example.com", "password": "wrongpassword1234"}` req := testutils.MakeReq(server.URL, "POST", "/v3/signin", dat) // Execute res := testutils.HTTPDo(t, req) // Test assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "") var user models.User models.MustExec(t, models.TestDB.Model(&models.User{}).First(&user), "finding user") assert.Equal(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch") var sessionCount int models.MustExec(t, models.TestDB.Model(&models.Session{}).Count(&sessionCount), "counting session") assert.Equal(t, sessionCount, 0, "sessionCount mismatch") }) t.Run("wrong email", func(t *testing.T) { defer models.ClearTestData(models.TestDB) // Setup server := MustNewServer(t, &app.App{ Clock: clock.NewMock(), }) defer server.Close() u := models.SetUpUserData() models.SetUpAccountData(u, "alice@example.com", "pass1234") dat := `{"email": "bob@example.com", "password": "pass1234"}` req := testutils.MakeReq(server.URL, "POST", "/v3/signin", dat) // Execute res := testutils.HTTPDo(t, req) // Test assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "") var user models.User models.MustExec(t, models.TestDB.Model(&models.User{}).First(&user), "finding user") assert.DeepEqual(t, user.LastLoginAt, (*time.Time)(nil), "LastLoginAt mismatch") var sessionCount int models.MustExec(t, models.TestDB.Model(&models.Session{}).Count(&sessionCount), "counting session") assert.Equal(t, sessionCount, 0, "sessionCount mismatch") }) t.Run("nonexistent email", func(t *testing.T) { defer models.ClearTestData(models.TestDB) // Setup server := MustNewServer(t, &app.App{ Clock: clock.NewMock(), }) defer server.Close() dat := `{"email": "nonexistent@example.com", "password": "pass1234"}` req := testutils.MakeReq(server.URL, "POST", "/v3/signin", dat) // Execute res := testutils.HTTPDo(t, req) // Test assert.StatusCodeEquals(t, res, http.StatusUnauthorized, "") var sessionCount int models.MustExec(t, models.TestDB.Model(&models.Session{}).Count(&sessionCount), "counting session") assert.Equal(t, sessionCount, 0, "sessionCount mismatch") }) } func TestSignout(t *testing.T) { t.Run("authenticated", func(t *testing.T) { defer models.ClearTestData(models.TestDB) aliceUser := models.SetUpUserData() models.SetUpAccountData(aliceUser, "alice@example.com", "pass1234") anotherUser := models.SetUpUserData() session1 := models.Session{ Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=", UserID: aliceUser.ID, ExpiresAt: time.Now().Add(time.Hour * 24), } models.MustExec(t, models.TestDB.Save(&session1), "preparing session1") session2 := models.Session{ Key: "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=", UserID: anotherUser.ID, ExpiresAt: time.Now().Add(time.Hour * 24), } models.MustExec(t, models.TestDB.Save(&session2), "preparing session2") // Setup server := MustNewServer(t, &app.App{ Clock: clock.NewMock(), }) defer server.Close() // Execute req := testutils.MakeReq(server.URL, "POST", "/v3/signout", "") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=")) res := testutils.HTTPDo(t, req) // Test assert.StatusCodeEquals(t, res, http.StatusNoContent, "Status mismatch") var sessionCount int var s2 models.Session models.MustExec(t, models.TestDB.Model(&models.Session{}).Count(&sessionCount), "counting session") models.MustExec(t, models.TestDB.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&s2), "getting s2") assert.Equal(t, sessionCount, 1, "sessionCount mismatch") c := testutils.GetCookieByName(res.Cookies(), "id") assert.Equal(t, c.Value, "", "session key mismatch") assert.Equal(t, c.Path, "/", "session path mismatch") assert.Equal(t, c.HttpOnly, true, "session HTTPOnly mismatch") if c.Expires.After(time.Now()) { t.Error("session cookie is not expired") } }) t.Run("unauthenticated", func(t *testing.T) { defer models.ClearTestData(models.TestDB) aliceUser := models.SetUpUserData() models.SetUpAccountData(aliceUser, "alice@example.com", "pass1234") anotherUser := models.SetUpUserData() session1 := models.Session{ Key: "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=", UserID: aliceUser.ID, ExpiresAt: time.Now().Add(time.Hour * 24), } models.MustExec(t, models.TestDB.Save(&session1), "preparing session1") session2 := models.Session{ Key: "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=", UserID: anotherUser.ID, ExpiresAt: time.Now().Add(time.Hour * 24), } models.MustExec(t, models.TestDB.Save(&session2), "preparing session2") // Setup server := MustNewServer(t, &app.App{ Clock: clock.NewMock(), }) defer server.Close() // Execute req := testutils.MakeReq(server.URL, "POST", "/v3/signout", "") res := testutils.HTTPDo(t, req) // Test assert.StatusCodeEquals(t, res, http.StatusNoContent, "Status mismatch") var sessionCount int var postSession1, postSession2 models.Session models.MustExec(t, models.TestDB.Model(&models.Session{}).Count(&sessionCount), "counting session") models.MustExec(t, models.TestDB.Where("key = ?", "A9xgggqzTHETy++GDi1NpDNe0iyqosPm9bitdeNGkJU=").First(&postSession1), "getting postSession1") models.MustExec(t, models.TestDB.Where("key = ?", "MDCpbvCRg7W2sH6S870wqLqZDZTObYeVd0PzOekfo/A=").First(&postSession2), "getting postSession2") // two existing sessions should remain assert.Equal(t, sessionCount, 2, "sessionCount mismatch") c := testutils.GetCookieByName(res.Cookies(), "id") assert.Equal(t, c, (*http.Cookie)(nil), "id cookie should have not been set") }) }