/* Copyright 2025 Dnote Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package client import ( "encoding/json" "fmt" "net/http" "net/http/httptest" "sync/atomic" "testing" "time" "github.com/dnote/dnote/pkg/assert" "github.com/dnote/dnote/pkg/cli/context" "github.com/dnote/dnote/pkg/cli/testutils" "github.com/pkg/errors" "golang.org/x/time/rate" ) // startCommonTestServer starts a test HTTP server that simulates a common set of senarios func startCommonTestServer() *httptest.Server { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // internal server error if r.URL.String() == "/bad-api/v3/signout" && r.Method == "POST" { w.Header().Set("Content-Type", "text/html; charset=utf-8") w.WriteHeader(http.StatusInternalServerError) return } // catch-all w.Header().Set("Content-Type", "text/html; charset=utf-8") w.WriteHeader(http.StatusOK) w.Write([]byte(`
`)) })) return ts } func TestSignIn(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.String() == "/api/v3/signin" && r.Method == "POST" { var payload SigninPayload err := json.NewDecoder(r.Body).Decode(&payload) if err != nil { t.Fatal(errors.Wrap(err, "decoding payload in the test server").Error()) return } if payload.Email == "alice@example.com" && payload.Passowrd == "pass1234" { resp := testutils.MustMarshalJSON(t, SigninResponse{ Key: "somekey", ExpiresAt: int64(1596439890), }) w.Header().Set("Content-Type", "application/json") w.Write(resp) } else { w.WriteHeader(http.StatusUnauthorized) } return } })) defer ts.Close() commonTs := startCommonTestServer() defer commonTs.Close() correctEndpoint := fmt.Sprintf("%s/api", ts.URL) testClient := NewRateLimitedHTTPClient() t.Run("success", func(t *testing.T) { result, err := Signin(context.DnoteCtx{APIEndpoint: correctEndpoint, HTTPClient: testClient}, "alice@example.com", "pass1234") if err != nil { t.Errorf("got signin request error: %+v", err.Error()) } assert.Equal(t, result.Key, "somekey", "Key mismatch") assert.Equal(t, result.ExpiresAt, int64(1596439890), "ExpiresAt mismatch") }) t.Run("failure", func(t *testing.T) { result, err := Signin(context.DnoteCtx{APIEndpoint: correctEndpoint, HTTPClient: testClient}, "alice@example.com", "incorrectpassword") assert.Equal(t, err, ErrInvalidLogin, "err mismatch") assert.Equal(t, result.Key, "", "Key mismatch") assert.Equal(t, result.ExpiresAt, int64(0), "ExpiresAt mismatch") }) t.Run("server error", func(t *testing.T) { endpoint := fmt.Sprintf("%s/bad-api", ts.URL) result, err := Signin(context.DnoteCtx{APIEndpoint: endpoint, HTTPClient: testClient}, "alice@example.com", "pass1234") if err == nil { t.Error("error should have been returned") } assert.Equal(t, result.Key, "", "Key mismatch") assert.Equal(t, result.ExpiresAt, int64(0), "ExpiresAt mismatch") }) t.Run("accidentally pointing to a catch-all handler", func(t *testing.T) { endpoint := fmt.Sprintf("%s", ts.URL) result, err := Signin(context.DnoteCtx{APIEndpoint: endpoint, HTTPClient: testClient}, "alice@example.com", "pass1234") assert.Equal(t, errors.Cause(err), ErrContentTypeMismatch, "error cause mismatch") assert.Equal(t, result.Key, "", "Key mismatch") assert.Equal(t, result.ExpiresAt, int64(0), "ExpiresAt mismatch") }) t.Run("network error", func(t *testing.T) { // Use an invalid endpoint that will fail to connect endpoint := "http://localhost:99999/api" result, err := Signin(context.DnoteCtx{APIEndpoint: endpoint, HTTPClient: testClient}, "alice@example.com", "pass1234") if err == nil { t.Error("error should have been returned for network failure") } assert.Equal(t, result.Key, "", "Key mismatch") assert.Equal(t, result.ExpiresAt, int64(0), "ExpiresAt mismatch") }) } func TestSignOut(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.String() == "/api/v3/signout" && r.Method == "POST" { w.WriteHeader(http.StatusNoContent) } })) defer ts.Close() commonTs := startCommonTestServer() defer commonTs.Close() correctEndpoint := fmt.Sprintf("%s/api", ts.URL) testClient := NewRateLimitedHTTPClient() t.Run("success", func(t *testing.T) { err := Signout(context.DnoteCtx{SessionKey: "somekey", APIEndpoint: correctEndpoint, HTTPClient: testClient}, "alice@example.com") if err != nil { t.Errorf("got signout request error: %+v", err.Error()) } }) t.Run("server error", func(t *testing.T) { endpoint := fmt.Sprintf("%s/bad-api", commonTs.URL) err := Signout(context.DnoteCtx{SessionKey: "somekey", APIEndpoint: endpoint, HTTPClient: testClient}, "alice@example.com") if err == nil { t.Error("error should have been returned") } }) t.Run("accidentally pointing to a catch-all handler", func(t *testing.T) { endpoint := fmt.Sprintf("%s", commonTs.URL) err := Signout(context.DnoteCtx{SessionKey: "somekey", APIEndpoint: endpoint, HTTPClient: testClient}, "alice@example.com") assert.Equal(t, errors.Cause(err), ErrContentTypeMismatch, "error cause mismatch") }) // Gracefully handle a case where http client was not initialized in the context. t.Run("nil HTTPClient", func(t *testing.T) { err := Signout(context.DnoteCtx{SessionKey: "somekey", APIEndpoint: correctEndpoint}, "alice@example.com") if err != nil { t.Errorf("got signout request error: %+v", err.Error()) } }) } func TestRateLimitedTransport(t *testing.T) { var requestCount atomic.Int32 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestCount.Add(1) w.WriteHeader(http.StatusOK) })) defer ts.Close() transport := &rateLimitedTransport{ transport: http.DefaultTransport, limiter: rate.NewLimiter(10, 5), } client := &http.Client{Transport: transport} // Make 10 requests start := time.Now() numRequests := 10 for i := range numRequests { req, _ := http.NewRequest("GET", ts.URL, nil) resp, err := client.Do(req) if err != nil { t.Fatalf("Request %d failed: %v", i, err) } resp.Body.Close() } elapsed := time.Since(start) // Burst of 5, then 5 more at 10 req/s = 500ms minimum if elapsed < 500*time.Millisecond { t.Errorf("Rate limit not enforced: 10 requests took %v, expected >= 500ms", elapsed) } assert.Equal(t, int(requestCount.Load()), 10, "request count mismatch") } func TestHTTPError(t *testing.T) { t.Run("IsConflict returns true for 409", func(t *testing.T) { conflictErr := &HTTPError{ StatusCode: 409, Message: "Conflict", } assert.Equal(t, conflictErr.IsConflict(), true, "IsConflict() should return true for 409") notFoundErr := &HTTPError{ StatusCode: 404, Message: "Not Found", } assert.Equal(t, notFoundErr.IsConflict(), false, "IsConflict() should return false for 404") }) }