mirror of
https://github.com/dnote/dnote
synced 2026-03-18 08:19:55 +01:00
Rate limit client (#689)
This commit is contained in:
parent
fd7b2a78b2
commit
e0f68fc8d8
9 changed files with 259 additions and 74 deletions
|
|
@ -33,6 +33,7 @@ import (
|
|||
"github.com/dnote/dnote/pkg/cli/context"
|
||||
"github.com/dnote/dnote/pkg/cli/log"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// ErrInvalidLogin is an error for invalid credentials for login
|
||||
|
|
@ -44,15 +45,66 @@ var ErrContentTypeMismatch = errors.New("content type mismatch")
|
|||
var contentTypeApplicationJSON = "application/json"
|
||||
var contentTypeNone = ""
|
||||
|
||||
// requestOptions contians options for requests
|
||||
// requestOptions contains options for requests
|
||||
type requestOptions struct {
|
||||
HTTPClient *http.Client
|
||||
// ExpectedContentType is the Content-Type that the client is expecting from the server
|
||||
ExpectedContentType *string
|
||||
}
|
||||
|
||||
var defaultRequestOptions = requestOptions{
|
||||
ExpectedContentType: &contentTypeApplicationJSON,
|
||||
const (
|
||||
// clientRateLimitPerSecond is the max requests per second the client will make
|
||||
clientRateLimitPerSecond = 50
|
||||
// clientRateLimitBurst is the burst capacity for rate limiting
|
||||
clientRateLimitBurst = 100
|
||||
)
|
||||
|
||||
// rateLimitedTransport wraps an http.RoundTripper with rate limiting
|
||||
type rateLimitedTransport struct {
|
||||
transport http.RoundTripper
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
func (t *rateLimitedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Wait for rate limiter to allow the request
|
||||
if err := t.limiter.Wait(req.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// NewRateLimitedHTTPClient creates an HTTP client with rate limiting
|
||||
func NewRateLimitedHTTPClient() *http.Client {
|
||||
// Calculate interval from rate: 1 second / requests per second
|
||||
interval := time.Second / time.Duration(clientRateLimitPerSecond)
|
||||
|
||||
transport := &rateLimitedTransport{
|
||||
transport: http.DefaultTransport,
|
||||
limiter: rate.NewLimiter(rate.Every(interval), clientRateLimitBurst),
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
func getHTTPClient(ctx context.DnoteCtx, options *requestOptions) *http.Client {
|
||||
if options != nil && options.HTTPClient != nil {
|
||||
return options.HTTPClient
|
||||
}
|
||||
|
||||
if ctx.HTTPClient != nil {
|
||||
return ctx.HTTPClient
|
||||
}
|
||||
|
||||
return &http.Client{}
|
||||
}
|
||||
|
||||
func getExpectedContentType(options *requestOptions) string {
|
||||
if options != nil && options.ExpectedContentType != nil {
|
||||
return *options.ExpectedContentType
|
||||
}
|
||||
|
||||
return contentTypeApplicationJSON
|
||||
}
|
||||
|
||||
func getReq(ctx context.DnoteCtx, path, method, body string) (*http.Request, error) {
|
||||
|
|
@ -72,22 +124,6 @@ func getReq(ctx context.DnoteCtx, path, method, body string) (*http.Request, err
|
|||
return req, nil
|
||||
}
|
||||
|
||||
func getHTTPClient(options *requestOptions) http.Client {
|
||||
if options != nil && options.HTTPClient != nil {
|
||||
return *options.HTTPClient
|
||||
}
|
||||
|
||||
return http.Client{}
|
||||
}
|
||||
|
||||
func getExpectedContentType(options *requestOptions) string {
|
||||
if options != nil && options.ExpectedContentType != nil {
|
||||
return *options.ExpectedContentType
|
||||
}
|
||||
|
||||
return contentTypeApplicationJSON
|
||||
}
|
||||
|
||||
// checkRespErr checks if the given http response indicates an error. It returns a boolean indicating
|
||||
// if the response is an error, and a decoded error message.
|
||||
func checkRespErr(res *http.Response) error {
|
||||
|
|
@ -124,7 +160,7 @@ func doReq(ctx context.DnoteCtx, method, path, body string, options *requestOpti
|
|||
|
||||
log.Debug("HTTP request: %+v\n", req)
|
||||
|
||||
hc := getHTTPClient(options)
|
||||
hc := getHTTPClient(ctx, options)
|
||||
res, err := hc.Do(req)
|
||||
if err != nil {
|
||||
return res, errors.Wrap(err, "making http request")
|
||||
|
|
@ -542,15 +578,27 @@ func Signin(ctx context.DnoteCtx, email, password string) (SigninResponse, error
|
|||
|
||||
// Signout deletes a user session on the server side
|
||||
func Signout(ctx context.DnoteCtx, sessionKey string) error {
|
||||
hc := http.Client{
|
||||
// No need to follow redirect
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
// Create a client that shares the transport (and thus rate limiter) from ctx.HTTPClient
|
||||
// but doesn't follow redirects
|
||||
var hc *http.Client
|
||||
if ctx.HTTPClient != nil {
|
||||
hc = &http.Client{
|
||||
Transport: ctx.HTTPClient.Transport,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
} else {
|
||||
log.Warnf("No HTTP client configured for signout - falling back\n")
|
||||
hc = &http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
opts := requestOptions{
|
||||
HTTPClient: &hc,
|
||||
HTTPClient: hc,
|
||||
ExpectedContentType: &contentTypeNone,
|
||||
}
|
||||
_, err := doAuthorizedReq(ctx, "POST", "/v3/signout", "", &opts)
|
||||
|
|
|
|||
|
|
@ -23,12 +23,15 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dnote/dnote/pkg/assert"
|
||||
"github.com/dnote/dnote/pkg/cli/context"
|
||||
"github.com/dnote/dnote/pkg/cli/testutils"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// startCommonTestServer starts a test HTTP server that simulates a common set of senarios
|
||||
|
|
@ -82,9 +85,10 @@ func TestSignIn(t *testing.T) {
|
|||
defer commonTs.Close()
|
||||
|
||||
correctEndpoint := fmt.Sprintf("%s/api", ts.URL)
|
||||
testClient := NewRateLimitedHTTPClient()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
result, err := Signin(context.DnoteCtx{APIEndpoint: correctEndpoint}, "alice@example.com", "pass1234")
|
||||
result, err := Signin(context.DnoteCtx{APIEndpoint: correctEndpoint, HTTPClient: testClient}, "alice@example.com", "pass1234")
|
||||
if err != nil {
|
||||
t.Errorf("got signin request error: %+v", err.Error())
|
||||
}
|
||||
|
|
@ -94,7 +98,7 @@ func TestSignIn(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("failure", func(t *testing.T) {
|
||||
result, err := Signin(context.DnoteCtx{APIEndpoint: correctEndpoint}, "alice@example.com", "incorrectpassword")
|
||||
result, err := Signin(context.DnoteCtx{APIEndpoint: correctEndpoint, HTTPClient: testClient}, "alice@example.com", "incorrectpassword")
|
||||
|
||||
assert.Equal(t, err, ErrInvalidLogin, "err mismatch")
|
||||
assert.Equal(t, result.Key, "", "Key mismatch")
|
||||
|
|
@ -103,7 +107,7 @@ func TestSignIn(t *testing.T) {
|
|||
|
||||
t.Run("server error", func(t *testing.T) {
|
||||
endpoint := fmt.Sprintf("%s/bad-api", ts.URL)
|
||||
result, err := Signin(context.DnoteCtx{APIEndpoint: endpoint}, "alice@example.com", "pass1234")
|
||||
result, err := Signin(context.DnoteCtx{APIEndpoint: endpoint, HTTPClient: testClient}, "alice@example.com", "pass1234")
|
||||
if err == nil {
|
||||
t.Error("error should have been returned")
|
||||
}
|
||||
|
|
@ -114,7 +118,7 @@ func TestSignIn(t *testing.T) {
|
|||
|
||||
t.Run("accidentally pointing to a catch-all handler", func(t *testing.T) {
|
||||
endpoint := fmt.Sprintf("%s", ts.URL)
|
||||
result, err := Signin(context.DnoteCtx{APIEndpoint: endpoint}, "alice@example.com", "pass1234")
|
||||
result, err := Signin(context.DnoteCtx{APIEndpoint: endpoint, HTTPClient: testClient}, "alice@example.com", "pass1234")
|
||||
|
||||
assert.Equal(t, errors.Cause(err), ErrContentTypeMismatch, "error cause mismatch")
|
||||
assert.Equal(t, result.Key, "", "Key mismatch")
|
||||
|
|
@ -134,17 +138,18 @@ func TestSignOut(t *testing.T) {
|
|||
defer commonTs.Close()
|
||||
|
||||
correctEndpoint := fmt.Sprintf("%s/api", ts.URL)
|
||||
testClient := NewRateLimitedHTTPClient()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
err := Signout(context.DnoteCtx{SessionKey: "somekey", APIEndpoint: correctEndpoint}, "alice@example.com")
|
||||
err := Signout(context.DnoteCtx{SessionKey: "somekey", APIEndpoint: correctEndpoint, HTTPClient: testClient}, "alice@example.com")
|
||||
if err != nil {
|
||||
t.Errorf("got signin request error: %+v", err.Error())
|
||||
t.Errorf("got signout request error: %+v", err.Error())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("server error", func(t *testing.T) {
|
||||
endpoint := fmt.Sprintf("%s/bad-api", commonTs.URL)
|
||||
err := Signout(context.DnoteCtx{SessionKey: "somekey", APIEndpoint: endpoint}, "alice@example.com")
|
||||
err := Signout(context.DnoteCtx{SessionKey: "somekey", APIEndpoint: endpoint, HTTPClient: testClient}, "alice@example.com")
|
||||
if err == nil {
|
||||
t.Error("error should have been returned")
|
||||
}
|
||||
|
|
@ -152,8 +157,51 @@ func TestSignOut(t *testing.T) {
|
|||
|
||||
t.Run("accidentally pointing to a catch-all handler", func(t *testing.T) {
|
||||
endpoint := fmt.Sprintf("%s", commonTs.URL)
|
||||
err := Signout(context.DnoteCtx{SessionKey: "somekey", APIEndpoint: endpoint}, "alice@example.com")
|
||||
err := Signout(context.DnoteCtx{SessionKey: "somekey", APIEndpoint: endpoint, HTTPClient: testClient}, "alice@example.com")
|
||||
|
||||
assert.Equal(t, errors.Cause(err), ErrContentTypeMismatch, "error cause mismatch")
|
||||
})
|
||||
|
||||
// Gracefully handle a case where http client was not initialized in the context.
|
||||
t.Run("nil HTTPClient", func(t *testing.T) {
|
||||
err := Signout(context.DnoteCtx{SessionKey: "somekey", APIEndpoint: correctEndpoint}, "alice@example.com")
|
||||
if err != nil {
|
||||
t.Errorf("got signout request error: %+v", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRateLimitedTransport(t *testing.T) {
|
||||
var requestCount atomic.Int32
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestCount.Add(1)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
transport := &rateLimitedTransport{
|
||||
transport: http.DefaultTransport,
|
||||
limiter: rate.NewLimiter(10, 5),
|
||||
}
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
// Make 10 requests
|
||||
start := time.Now()
|
||||
numRequests := 10
|
||||
for i := range numRequests {
|
||||
req, _ := http.NewRequest("GET", ts.URL, nil)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request %d failed: %v", i, err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Burst of 5, then 5 more at 10 req/s = 500ms minimum
|
||||
if elapsed < 500*time.Millisecond {
|
||||
t.Errorf("Rate limit not enforced: 10 requests took %v, expected >= 500ms", elapsed)
|
||||
}
|
||||
|
||||
assert.Equal(t, int(requestCount.Load()), 10, "request count mismatch")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@
|
|||
package context
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/dnote/dnote/pkg/cli/database"
|
||||
"github.com/dnote/dnote/pkg/clock"
|
||||
)
|
||||
|
|
@ -44,6 +46,7 @@ type DnoteCtx struct {
|
|||
Editor string
|
||||
Clock clock.Clock
|
||||
EnableUpgradeCheck bool
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
// Redact replaces private information from the context with a set of
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ import (
|
|||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/dnote/dnote/pkg/cli/client"
|
||||
"github.com/dnote/dnote/pkg/cli/config"
|
||||
"github.com/dnote/dnote/pkg/cli/consts"
|
||||
"github.com/dnote/dnote/pkg/cli/context"
|
||||
|
|
@ -159,6 +160,7 @@ func SetupCtx(ctx context.DnoteCtx) (context.DnoteCtx, error) {
|
|||
Editor: cf.Editor,
|
||||
Clock: clock.New(),
|
||||
EnableUpgradeCheck: cf.EnableUpgradeCheck,
|
||||
HTTPClient: client.NewRateLimitedHTTPClient(),
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
|
|
|
|||
|
|
@ -95,7 +95,6 @@ func TestMain(m *testing.M) {
|
|||
a.EmailTemplates = mailer.Templates{}
|
||||
a.EmailBackend = &apitest.MockEmailbackendImplementation{}
|
||||
a.DB = serverDb
|
||||
a.WebURL = os.Getenv("WebURL")
|
||||
|
||||
var err error
|
||||
server, err = controllers.NewServer(&a)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,2 @@
|
|||
APP_ENV=DEVELOPMENT
|
||||
|
||||
SmtpUsername=mock-SmtpUsername
|
||||
SmtpPassword=mock-SmtpPassword
|
||||
SmtpHost=mock-SmtpHost
|
||||
SmtpPort=465
|
||||
|
||||
WebURL=http://localhost:3001
|
||||
DisableRegistration=false
|
||||
DBPath=../../dev-server.db
|
||||
|
|
|
|||
|
|
@ -1,9 +1 @@
|
|||
APP_ENV=TEST
|
||||
|
||||
SmtpUsername=mock-SmtpUsername
|
||||
SmtpPassword=mock-SmtpPassword
|
||||
SmtpHost=mock-SmtpHost
|
||||
SmtpPort=465
|
||||
|
||||
WebURL=http://localhost:3001
|
||||
DisableRegistration=false
|
||||
|
|
|
|||
|
|
@ -29,63 +29,81 @@ import (
|
|||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
const (
|
||||
// serverRateLimitPerSecond is the max requests per second the server will accept per IP
|
||||
serverRateLimitPerSecond = 50
|
||||
// serverRateLimitBurst is the burst capacity for rate limiting
|
||||
serverRateLimitBurst = 100
|
||||
)
|
||||
|
||||
type visitor struct {
|
||||
limiter *rate.Limiter
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
var visitors = make(map[string]*visitor)
|
||||
var mtx sync.RWMutex
|
||||
|
||||
func init() {
|
||||
go cleanupVisitors()
|
||||
// RateLimiter holds the rate limiting state for visitors
|
||||
type RateLimiter struct {
|
||||
visitors map[string]*visitor
|
||||
mtx sync.RWMutex
|
||||
}
|
||||
|
||||
// addVisitor adds a new visitor to the map and returns a limiter for the visitor
|
||||
func addVisitor(identifier string) *rate.Limiter {
|
||||
// initialize a token bucket
|
||||
limiter := rate.NewLimiter(rate.Every(1*time.Second), 60)
|
||||
// NewRateLimiter creates a new rate limiter instance
|
||||
func NewRateLimiter() *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
visitors: make(map[string]*visitor),
|
||||
}
|
||||
go rl.cleanupVisitors()
|
||||
return rl
|
||||
}
|
||||
|
||||
mtx.Lock()
|
||||
visitors[identifier] = &visitor{
|
||||
var defaultLimiter = NewRateLimiter()
|
||||
|
||||
// addVisitor adds a new visitor to the map and returns a limiter for the visitor
|
||||
func (rl *RateLimiter) addVisitor(identifier string) *rate.Limiter {
|
||||
// Calculate interval from rate: 1 second / requests per second
|
||||
interval := time.Second / time.Duration(serverRateLimitPerSecond)
|
||||
limiter := rate.NewLimiter(rate.Every(interval), serverRateLimitBurst)
|
||||
|
||||
rl.mtx.Lock()
|
||||
rl.visitors[identifier] = &visitor{
|
||||
limiter: limiter,
|
||||
lastSeen: time.Now()}
|
||||
mtx.Unlock()
|
||||
rl.mtx.Unlock()
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// getVisitor returns a limiter for a visitor with the given identifier. It
|
||||
// adds the visitor to the map if not seen before.
|
||||
func getVisitor(identifier string) *rate.Limiter {
|
||||
mtx.RLock()
|
||||
v, exists := visitors[identifier]
|
||||
func (rl *RateLimiter) getVisitor(identifier string) *rate.Limiter {
|
||||
rl.mtx.RLock()
|
||||
v, exists := rl.visitors[identifier]
|
||||
|
||||
if !exists {
|
||||
mtx.RUnlock()
|
||||
return addVisitor(identifier)
|
||||
rl.mtx.RUnlock()
|
||||
return rl.addVisitor(identifier)
|
||||
}
|
||||
|
||||
v.lastSeen = time.Now()
|
||||
mtx.RUnlock()
|
||||
rl.mtx.RUnlock()
|
||||
|
||||
return v.limiter
|
||||
}
|
||||
|
||||
// cleanupVisitors deletes visitors that has not been seen in a while from the
|
||||
// map of visitors
|
||||
func cleanupVisitors() {
|
||||
func (rl *RateLimiter) cleanupVisitors() {
|
||||
for {
|
||||
time.Sleep(time.Minute)
|
||||
mtx.Lock()
|
||||
rl.mtx.Lock()
|
||||
|
||||
for identifier, v := range visitors {
|
||||
for identifier, v := range rl.visitors {
|
||||
if time.Since(v.lastSeen) > 3*time.Minute {
|
||||
delete(visitors, identifier)
|
||||
delete(rl.visitors, identifier)
|
||||
}
|
||||
}
|
||||
|
||||
mtx.Unlock()
|
||||
rl.mtx.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -107,10 +125,10 @@ func lookupIP(r *http.Request) string {
|
|||
}
|
||||
|
||||
// Limit is a middleware to rate limit the handler
|
||||
func Limit(next http.Handler) http.HandlerFunc {
|
||||
func (rl *RateLimiter) Limit(next http.Handler) http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
identifier := lookupIP(r)
|
||||
limiter := getVisitor(identifier)
|
||||
limiter := rl.getVisitor(identifier)
|
||||
|
||||
if !limiter.Allow() {
|
||||
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
||||
|
|
@ -124,12 +142,12 @@ func Limit(next http.Handler) http.HandlerFunc {
|
|||
})
|
||||
}
|
||||
|
||||
// ApplyLimit applies rate limit conditionally
|
||||
// ApplyLimit applies rate limit conditionally using the global limiter
|
||||
func ApplyLimit(h http.HandlerFunc, rateLimit bool) http.Handler {
|
||||
ret := h
|
||||
|
||||
if rateLimit && os.Getenv("APP_ENV") != "TEST" {
|
||||
ret = Limit(ret)
|
||||
ret = defaultLimiter.Limit(ret)
|
||||
}
|
||||
|
||||
return ret
|
||||
|
|
|
|||
82
pkg/server/middleware/limit_test.go
Normal file
82
pkg/server/middleware/limit_test.go
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
/* Copyright (C) 2019, 2020, 2021, 2022, 2023, 2024, 2025 Dnote contributors
|
||||
*
|
||||
* This file is part of Dnote.
|
||||
*
|
||||
* Dnote is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* Dnote is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with Dnote. If not, see <https://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLimit(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
limiter := NewRateLimiter()
|
||||
middleware := limiter.Limit(handler)
|
||||
|
||||
// Make burst + 5 requests from same IP
|
||||
numRequests := serverRateLimitBurst + 5
|
||||
blockedCount := 0
|
||||
|
||||
for range numRequests {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
middleware.ServeHTTP(w, req)
|
||||
|
||||
if w.Code == http.StatusTooManyRequests {
|
||||
blockedCount++
|
||||
}
|
||||
}
|
||||
|
||||
// At least some requests after burst should be blocked
|
||||
if blockedCount == 0 {
|
||||
t.Error("Expected some requests to be rate limited after burst")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimit_DifferentIPs(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
limiter := NewRateLimiter()
|
||||
middleware := limiter.Limit(handler)
|
||||
|
||||
// Exhaust rate limit for first IP
|
||||
for range serverRateLimitBurst + 5 {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:1234"
|
||||
w := httptest.NewRecorder()
|
||||
middleware.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// Request from different IP should still succeed
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.2:5678"
|
||||
w := httptest.NewRecorder()
|
||||
middleware.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Request from different IP should succeed, got status %d", w.Code)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue