dnote/pkg/server/middleware/limit.go
Sung e72322f847
Simplify email backend and remove --appEnv (#710)
* Improve logging

* Remove AppEnv

* Simplify email backend
2025-11-01 00:54:27 -07:00

150 lines
3.6 KiB
Go

/* Copyright 2025 Dnote Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package middleware
import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/dnote/dnote/pkg/server/log"
	"golang.org/x/time/rate"
)
const (
	// serverRateLimitPerSecond is the sustained number of requests per second
	// allowed for a single visitor.
	serverRateLimitPerSecond = 50
	// serverRateLimitBurst is the number of requests a visitor may issue at
	// once before the per-second rate applies.
	serverRateLimitBurst = 100
)
// visitor is the per-identifier state kept by RateLimiter.
type visitor struct {
	// limiter decides whether a request from this visitor is allowed.
	limiter *rate.Limiter
	// lastSeen is the time of the visitor's most recent lookup; the cleanup
	// loop uses it to evict idle entries.
	lastSeen time.Time
}
// RateLimiter keeps one limiter per visitor identifier.
type RateLimiter struct {
	// visitors maps an identifier (as returned by lookupIP) to its state.
	visitors map[string]*visitor
	// mtx guards visitors and the entries stored in it.
	mtx sync.RWMutex
}
// NewRateLimiter returns a RateLimiter with an empty visitor table and starts
// the background goroutine that evicts idle visitors. The goroutine runs for
// the lifetime of the process.
func NewRateLimiter() *RateLimiter {
	limiter := new(RateLimiter)
	limiter.visitors = make(map[string]*visitor)

	go limiter.cleanupVisitors()

	return limiter
}
// defaultLimiter is the package-wide limiter used by ApplyLimit. It is created
// at package initialization, which also starts its cleanup goroutine.
var defaultLimiter = NewRateLimiter()
// addVisitor registers a visitor for identifier and returns its limiter.
//
// If an entry for identifier already exists, that entry's limiter is returned
// and its lastSeen is refreshed instead of being replaced. Callers may have
// observed a miss under a lock they have since released, so two goroutines can
// arrive here for the same identifier; overwriting would hand them different
// limiters and discard the tokens already consumed from the first one.
func (rl *RateLimiter) addVisitor(identifier string) *rate.Limiter {
	rl.mtx.Lock()
	defer rl.mtx.Unlock()

	// Re-check under the write lock before inserting.
	if v, ok := rl.visitors[identifier]; ok {
		v.lastSeen = time.Now()
		return v.limiter
	}

	// Calculate interval from rate: 1 second / requests per second.
	interval := time.Second / time.Duration(serverRateLimitPerSecond)
	limiter := rate.NewLimiter(rate.Every(interval), serverRateLimitBurst)

	rl.visitors[identifier] = &visitor{
		limiter:  limiter,
		lastSeen: time.Now(),
	}

	return limiter
}
// getVisitor returns the limiter for the visitor with the given identifier,
// refreshing the visitor's lastSeen timestamp. A visitor that has not been
// seen before is added via addVisitor.
//
// The lookup takes the exclusive lock rather than the read lock because a hit
// writes lastSeen; doing that write under RLock would race with other
// goroutines handling requests for the same identifier.
func (rl *RateLimiter) getVisitor(identifier string) *rate.Limiter {
	rl.mtx.Lock()
	v, exists := rl.visitors[identifier]
	if exists {
		v.lastSeen = time.Now()
	}
	rl.mtx.Unlock()

	if !exists {
		// addVisitor acquires the lock itself, so it must be called after
		// the lock has been released.
		return rl.addVisitor(identifier)
	}

	return v.limiter
}
// cleanupVisitors periodically deletes visitors that have not been seen in a
// while from the map of visitors. It never returns and is meant to run in its
// own goroutine.
func (rl *RateLimiter) cleanupVisitors() {
	const (
		// sweepInterval is the pause between two passes over the map.
		sweepInterval = time.Minute
		// idleTimeout is how long a visitor may stay unseen before eviction.
		idleTimeout = 3 * time.Minute
	)

	for {
		time.Sleep(sweepInterval)

		rl.mtx.Lock()
		for id, entry := range rl.visitors {
			if time.Since(entry.lastSeen) <= idleTimeout {
				continue
			}
			delete(rl.visitors, id)
		}
		rl.mtx.Unlock()
	}
}
// lookupIP returns the IP address used to identify the client of r.
//
// Sources are consulted in this order:
//  1. the first element of the X-Forwarded-For header,
//  2. the X-Real-IP header,
//  3. the host part of r.RemoteAddr.
//
// Header values are trimmed of surrounding whitespace, and an empty value
// falls through to the next source. The port is stripped from RemoteAddr so
// that separate connections from one client map to the same identifier.
//
// NOTE(review): both headers are supplied by the peer. Unless a trusted reverse
// proxy overwrites them, a client can choose its own identifier per request.
// Confirm the deployment always sits behind such a proxy.
func lookupIP(r *http.Request) string {
	if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" {
		first := forwardedFor
		if i := strings.IndexByte(forwardedFor, ','); i >= 0 {
			first = forwardedFor[:i]
		}
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); realIP != "" {
		return realIP
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr is not in host:port form; use it unchanged.
		return r.RemoteAddr
	}

	return host
}
// Limit wraps next with rate limiting. Each request is attributed to the
// identifier returned by lookupIP and checked against that visitor's limiter.
// A request over the limit receives a 429 response and a warning is logged;
// next is not invoked for it.
func (rl *RateLimiter) Limit(next http.Handler) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ip := lookupIP(r)

		if rl.getVisitor(ip).Allow() {
			next.ServeHTTP(w, r)
			return
		}

		http.Error(w, "Too many requests", http.StatusTooManyRequests)
		log.WithFields(log.Fields{
			"ip": ip,
		}).Warn("Too many requests")
	}
}
// ApplyLimit returns h unchanged when rateLimit is false, and h wrapped by the
// global limiter's Limit middleware when rateLimit is true.
func ApplyLimit(h http.HandlerFunc, rateLimit bool) http.Handler {
	if !rateLimit {
		return h
	}

	return defaultLimiter.Limit(h)
}