Rate limit client (#689)

This commit is contained in:
Sung 2025-10-11 16:14:20 -07:00 committed by GitHub
commit e0f68fc8d8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 259 additions and 74 deletions

View file

@ -33,6 +33,7 @@ import (
"github.com/dnote/dnote/pkg/cli/context"
"github.com/dnote/dnote/pkg/cli/log"
"github.com/pkg/errors"
"golang.org/x/time/rate"
)
// ErrInvalidLogin is an error for invalid credentials for login
@ -44,15 +45,66 @@ var ErrContentTypeMismatch = errors.New("content type mismatch")
var contentTypeApplicationJSON = "application/json"
var contentTypeNone = ""
// requestOptions contians options for requests
// requestOptions contains options for requests
type requestOptions struct {
HTTPClient *http.Client
// ExpectedContentType is the Content-Type that the client is expecting from the server
ExpectedContentType *string
}
var defaultRequestOptions = requestOptions{
ExpectedContentType: &contentTypeApplicationJSON,
const (
// clientRateLimitPerSecond is the max requests per second the client will make
clientRateLimitPerSecond = 50
// clientRateLimitBurst is the burst capacity for rate limiting
clientRateLimitBurst = 100
)
// rateLimitedTransport wraps an http.RoundTripper with rate limiting
type rateLimitedTransport struct {
transport http.RoundTripper
limiter *rate.Limiter
}
func (t *rateLimitedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Wait for rate limiter to allow the request
if err := t.limiter.Wait(req.Context()); err != nil {
return nil, err
}
return t.transport.RoundTrip(req)
}
// NewRateLimitedHTTPClient creates an HTTP client with rate limiting
func NewRateLimitedHTTPClient() *http.Client {
// Calculate interval from rate: 1 second / requests per second
interval := time.Second / time.Duration(clientRateLimitPerSecond)
transport := &rateLimitedTransport{
transport: http.DefaultTransport,
limiter: rate.NewLimiter(rate.Every(interval), clientRateLimitBurst),
}
return &http.Client{
Transport: transport,
}
}
func getHTTPClient(ctx context.DnoteCtx, options *requestOptions) *http.Client {
if options != nil && options.HTTPClient != nil {
return options.HTTPClient
}
if ctx.HTTPClient != nil {
return ctx.HTTPClient
}
return &http.Client{}
}
func getExpectedContentType(options *requestOptions) string {
if options != nil && options.ExpectedContentType != nil {
return *options.ExpectedContentType
}
return contentTypeApplicationJSON
}
func getReq(ctx context.DnoteCtx, path, method, body string) (*http.Request, error) {
@ -72,22 +124,6 @@ func getReq(ctx context.DnoteCtx, path, method, body string) (*http.Request, err
return req, nil
}
func getHTTPClient(options *requestOptions) http.Client {
if options != nil && options.HTTPClient != nil {
return *options.HTTPClient
}
return http.Client{}
}
func getExpectedContentType(options *requestOptions) string {
if options != nil && options.ExpectedContentType != nil {
return *options.ExpectedContentType
}
return contentTypeApplicationJSON
}
// checkRespErr checks if the given http response indicates an error. It returns a boolean indicating
// if the response is an error, and a decoded error message.
func checkRespErr(res *http.Response) error {
@ -124,7 +160,7 @@ func doReq(ctx context.DnoteCtx, method, path, body string, options *requestOpti
log.Debug("HTTP request: %+v\n", req)
hc := getHTTPClient(options)
hc := getHTTPClient(ctx, options)
res, err := hc.Do(req)
if err != nil {
return res, errors.Wrap(err, "making http request")
@ -542,15 +578,27 @@ func Signin(ctx context.DnoteCtx, email, password string) (SigninResponse, error
// Signout deletes a user session on the server side
func Signout(ctx context.DnoteCtx, sessionKey string) error {
hc := http.Client{
// No need to follow redirect
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
// Create a client that shares the transport (and thus rate limiter) from ctx.HTTPClient
// but doesn't follow redirects
var hc *http.Client
if ctx.HTTPClient != nil {
hc = &http.Client{
Transport: ctx.HTTPClient.Transport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
} else {
log.Warnf("No HTTP client configured for signout - falling back\n")
hc = &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
}
opts := requestOptions{
HTTPClient: &hc,
HTTPClient: hc,
ExpectedContentType: &contentTypeNone,
}
_, err := doAuthorizedReq(ctx, "POST", "/v3/signout", "", &opts)

View file

@ -23,12 +23,15 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/dnote/dnote/pkg/assert"
"github.com/dnote/dnote/pkg/cli/context"
"github.com/dnote/dnote/pkg/cli/testutils"
"github.com/pkg/errors"
"golang.org/x/time/rate"
)
// startCommonTestServer starts a test HTTP server that simulates a common set of senarios
@ -82,9 +85,10 @@ func TestSignIn(t *testing.T) {
defer commonTs.Close()
correctEndpoint := fmt.Sprintf("%s/api", ts.URL)
testClient := NewRateLimitedHTTPClient()
t.Run("success", func(t *testing.T) {
result, err := Signin(context.DnoteCtx{APIEndpoint: correctEndpoint}, "alice@example.com", "pass1234")
result, err := Signin(context.DnoteCtx{APIEndpoint: correctEndpoint, HTTPClient: testClient}, "alice@example.com", "pass1234")
if err != nil {
t.Errorf("got signin request error: %+v", err.Error())
}
@ -94,7 +98,7 @@ func TestSignIn(t *testing.T) {
})
t.Run("failure", func(t *testing.T) {
result, err := Signin(context.DnoteCtx{APIEndpoint: correctEndpoint}, "alice@example.com", "incorrectpassword")
result, err := Signin(context.DnoteCtx{APIEndpoint: correctEndpoint, HTTPClient: testClient}, "alice@example.com", "incorrectpassword")
assert.Equal(t, err, ErrInvalidLogin, "err mismatch")
assert.Equal(t, result.Key, "", "Key mismatch")
@ -103,7 +107,7 @@ func TestSignIn(t *testing.T) {
t.Run("server error", func(t *testing.T) {
endpoint := fmt.Sprintf("%s/bad-api", ts.URL)
result, err := Signin(context.DnoteCtx{APIEndpoint: endpoint}, "alice@example.com", "pass1234")
result, err := Signin(context.DnoteCtx{APIEndpoint: endpoint, HTTPClient: testClient}, "alice@example.com", "pass1234")
if err == nil {
t.Error("error should have been returned")
}
@ -114,7 +118,7 @@ func TestSignIn(t *testing.T) {
t.Run("accidentally pointing to a catch-all handler", func(t *testing.T) {
endpoint := fmt.Sprintf("%s", ts.URL)
result, err := Signin(context.DnoteCtx{APIEndpoint: endpoint}, "alice@example.com", "pass1234")
result, err := Signin(context.DnoteCtx{APIEndpoint: endpoint, HTTPClient: testClient}, "alice@example.com", "pass1234")
assert.Equal(t, errors.Cause(err), ErrContentTypeMismatch, "error cause mismatch")
assert.Equal(t, result.Key, "", "Key mismatch")
@ -134,17 +138,18 @@ func TestSignOut(t *testing.T) {
defer commonTs.Close()
correctEndpoint := fmt.Sprintf("%s/api", ts.URL)
testClient := NewRateLimitedHTTPClient()
t.Run("success", func(t *testing.T) {
err := Signout(context.DnoteCtx{SessionKey: "somekey", APIEndpoint: correctEndpoint}, "alice@example.com")
err := Signout(context.DnoteCtx{SessionKey: "somekey", APIEndpoint: correctEndpoint, HTTPClient: testClient}, "alice@example.com")
if err != nil {
t.Errorf("got signin request error: %+v", err.Error())
t.Errorf("got signout request error: %+v", err.Error())
}
})
t.Run("server error", func(t *testing.T) {
endpoint := fmt.Sprintf("%s/bad-api", commonTs.URL)
err := Signout(context.DnoteCtx{SessionKey: "somekey", APIEndpoint: endpoint}, "alice@example.com")
err := Signout(context.DnoteCtx{SessionKey: "somekey", APIEndpoint: endpoint, HTTPClient: testClient}, "alice@example.com")
if err == nil {
t.Error("error should have been returned")
}
@ -152,8 +157,51 @@ func TestSignOut(t *testing.T) {
t.Run("accidentally pointing to a catch-all handler", func(t *testing.T) {
endpoint := fmt.Sprintf("%s", commonTs.URL)
err := Signout(context.DnoteCtx{SessionKey: "somekey", APIEndpoint: endpoint}, "alice@example.com")
err := Signout(context.DnoteCtx{SessionKey: "somekey", APIEndpoint: endpoint, HTTPClient: testClient}, "alice@example.com")
assert.Equal(t, errors.Cause(err), ErrContentTypeMismatch, "error cause mismatch")
})
// Gracefully handle a case where http client was not initialized in the context.
t.Run("nil HTTPClient", func(t *testing.T) {
err := Signout(context.DnoteCtx{SessionKey: "somekey", APIEndpoint: correctEndpoint}, "alice@example.com")
if err != nil {
t.Errorf("got signout request error: %+v", err.Error())
}
})
}
func TestRateLimitedTransport(t *testing.T) {
var requestCount atomic.Int32
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount.Add(1)
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
transport := &rateLimitedTransport{
transport: http.DefaultTransport,
limiter: rate.NewLimiter(10, 5),
}
client := &http.Client{Transport: transport}
// Make 10 requests
start := time.Now()
numRequests := 10
for i := range numRequests {
req, _ := http.NewRequest("GET", ts.URL, nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Request %d failed: %v", i, err)
}
resp.Body.Close()
}
elapsed := time.Since(start)
// Burst of 5, then 5 more at 10 req/s = 500ms minimum
if elapsed < 500*time.Millisecond {
t.Errorf("Rate limit not enforced: 10 requests took %v, expected >= 500ms", elapsed)
}
assert.Equal(t, int(requestCount.Load()), 10, "request count mismatch")
}

View file

@ -20,6 +20,8 @@
package context
import (
"net/http"
"github.com/dnote/dnote/pkg/cli/database"
"github.com/dnote/dnote/pkg/clock"
)
@ -44,6 +46,7 @@ type DnoteCtx struct {
Editor string
Clock clock.Clock
EnableUpgradeCheck bool
HTTPClient *http.Client
}
// Redact replaces private information from the context with a set of

View file

@ -28,6 +28,7 @@ import (
"strconv"
"time"
"github.com/dnote/dnote/pkg/cli/client"
"github.com/dnote/dnote/pkg/cli/config"
"github.com/dnote/dnote/pkg/cli/consts"
"github.com/dnote/dnote/pkg/cli/context"
@ -159,6 +160,7 @@ func SetupCtx(ctx context.DnoteCtx) (context.DnoteCtx, error) {
Editor: cf.Editor,
Clock: clock.New(),
EnableUpgradeCheck: cf.EnableUpgradeCheck,
HTTPClient: client.NewRateLimitedHTTPClient(),
}
return ret, nil

View file

@ -95,7 +95,6 @@ func TestMain(m *testing.M) {
a.EmailTemplates = mailer.Templates{}
a.EmailBackend = &apitest.MockEmailbackendImplementation{}
a.DB = serverDb
a.WebURL = os.Getenv("WebURL")
var err error
server, err = controllers.NewServer(&a)

View file

@ -1,9 +1,2 @@
APP_ENV=DEVELOPMENT
SmtpUsername=mock-SmtpUsername
SmtpPassword=mock-SmtpPassword
SmtpHost=mock-SmtpHost
SmtpPort=465
WebURL=http://localhost:3001
DisableRegistration=false
DBPath=../../dev-server.db

View file

@ -1,9 +1 @@
APP_ENV=TEST
SmtpUsername=mock-SmtpUsername
SmtpPassword=mock-SmtpPassword
SmtpHost=mock-SmtpHost
SmtpPort=465
WebURL=http://localhost:3001
DisableRegistration=false

View file

@ -29,63 +29,81 @@ import (
"golang.org/x/time/rate"
)
const (
// serverRateLimitPerSecond is the max requests per second the server will accept per IP
serverRateLimitPerSecond = 50
// serverRateLimitBurst is the burst capacity for rate limiting
serverRateLimitBurst = 100
)
type visitor struct {
limiter *rate.Limiter
lastSeen time.Time
}
var visitors = make(map[string]*visitor)
var mtx sync.RWMutex
func init() {
go cleanupVisitors()
// RateLimiter holds the rate limiting state for visitors
type RateLimiter struct {
visitors map[string]*visitor
mtx sync.RWMutex
}
// addVisitor adds a new visitor to the map and returns a limiter for the visitor
func addVisitor(identifier string) *rate.Limiter {
// initialize a token bucket
limiter := rate.NewLimiter(rate.Every(1*time.Second), 60)
// NewRateLimiter creates a new rate limiter instance
func NewRateLimiter() *RateLimiter {
rl := &RateLimiter{
visitors: make(map[string]*visitor),
}
go rl.cleanupVisitors()
return rl
}
mtx.Lock()
visitors[identifier] = &visitor{
var defaultLimiter = NewRateLimiter()
// addVisitor adds a new visitor to the map and returns a limiter for the visitor
func (rl *RateLimiter) addVisitor(identifier string) *rate.Limiter {
// Calculate interval from rate: 1 second / requests per second
interval := time.Second / time.Duration(serverRateLimitPerSecond)
limiter := rate.NewLimiter(rate.Every(interval), serverRateLimitBurst)
rl.mtx.Lock()
rl.visitors[identifier] = &visitor{
limiter: limiter,
lastSeen: time.Now()}
mtx.Unlock()
rl.mtx.Unlock()
return limiter
}
// getVisitor returns a limiter for a visitor with the given identifier. It
// adds the visitor to the map if not seen before.
func getVisitor(identifier string) *rate.Limiter {
mtx.RLock()
v, exists := visitors[identifier]
func (rl *RateLimiter) getVisitor(identifier string) *rate.Limiter {
rl.mtx.RLock()
v, exists := rl.visitors[identifier]
if !exists {
mtx.RUnlock()
return addVisitor(identifier)
rl.mtx.RUnlock()
return rl.addVisitor(identifier)
}
v.lastSeen = time.Now()
mtx.RUnlock()
rl.mtx.RUnlock()
return v.limiter
}
// cleanupVisitors deletes visitors that has not been seen in a while from the
// map of visitors
func cleanupVisitors() {
func (rl *RateLimiter) cleanupVisitors() {
for {
time.Sleep(time.Minute)
mtx.Lock()
rl.mtx.Lock()
for identifier, v := range visitors {
for identifier, v := range rl.visitors {
if time.Since(v.lastSeen) > 3*time.Minute {
delete(visitors, identifier)
delete(rl.visitors, identifier)
}
}
mtx.Unlock()
rl.mtx.Unlock()
}
}
@ -107,10 +125,10 @@ func lookupIP(r *http.Request) string {
}
// Limit is a middleware to rate limit the handler
func Limit(next http.Handler) http.HandlerFunc {
func (rl *RateLimiter) Limit(next http.Handler) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
identifier := lookupIP(r)
limiter := getVisitor(identifier)
limiter := rl.getVisitor(identifier)
if !limiter.Allow() {
http.Error(w, "Too many requests", http.StatusTooManyRequests)
@ -124,12 +142,12 @@ func Limit(next http.Handler) http.HandlerFunc {
})
}
// ApplyLimit applies rate limit conditionally
// ApplyLimit applies rate limit conditionally using the global limiter
func ApplyLimit(h http.HandlerFunc, rateLimit bool) http.Handler {
ret := h
if rateLimit && os.Getenv("APP_ENV") != "TEST" {
ret = Limit(ret)
ret = defaultLimiter.Limit(ret)
}
return ret

View file

@ -0,0 +1,82 @@
/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
*
* This file is part of Dnote.
*
* Dnote is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Dnote is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with Dnote. If not, see <https://www.gnu.org/licenses/>.
*/
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestLimit(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
limiter := NewRateLimiter()
middleware := limiter.Limit(handler)
// Make burst + 5 requests from same IP
numRequests := serverRateLimitBurst + 5
blockedCount := 0
for range numRequests {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:1234"
w := httptest.NewRecorder()
middleware.ServeHTTP(w, req)
if w.Code == http.StatusTooManyRequests {
blockedCount++
}
}
// At least some requests after burst should be blocked
if blockedCount == 0 {
t.Error("Expected some requests to be rate limited after burst")
}
}
func TestLimit_DifferentIPs(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
limiter := NewRateLimiter()
middleware := limiter.Limit(handler)
// Exhaust rate limit for first IP
for range serverRateLimitBurst + 5 {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:1234"
w := httptest.NewRecorder()
middleware.ServeHTTP(w, req)
}
// Request from different IP should still succeed
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.2:5678"
w := httptest.NewRecorder()
middleware.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Request from different IP should succeed, got status %d", w.Code)
}
}