mirror of
https://github.com/dnote/dnote
synced 2026-03-14 22:45:50 +01:00
150 lines
3.6 KiB
Go
150 lines
3.6 KiB
Go
/* Copyright 2025 Dnote Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package middleware
|
|
|
|
import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/dnote/dnote/pkg/server/log"
	"golang.org/x/time/rate"
)
|
|
|
|
// Token-bucket parameters applied to every visitor's limiter.
const (
	// serverRateLimitPerSecond is the sustained number of requests per second
	// accepted from a single visitor.
	serverRateLimitPerSecond = 50
	// serverRateLimitBurst is the bucket size, i.e. how many requests a
	// visitor may issue at once before the sustained rate applies.
	serverRateLimitBurst = 100
)
|
|
|
|
type visitor struct {
|
|
limiter *rate.Limiter
|
|
lastSeen time.Time
|
|
}
|
|
|
|
// RateLimiter holds the rate limiting state for visitors
|
|
type RateLimiter struct {
|
|
visitors map[string]*visitor
|
|
mtx sync.RWMutex
|
|
}
|
|
|
|
// NewRateLimiter creates a new rate limiter instance
|
|
func NewRateLimiter() *RateLimiter {
|
|
rl := &RateLimiter{
|
|
visitors: make(map[string]*visitor),
|
|
}
|
|
go rl.cleanupVisitors()
|
|
return rl
|
|
}
|
|
|
|
// defaultLimiter is the package-level limiter used by ApplyLimit. It is
// created at package initialization, which also starts its cleanup goroutine.
var defaultLimiter = NewRateLimiter()
|
|
|
|
// addVisitor adds a new visitor to the map and returns a limiter for the visitor
|
|
func (rl *RateLimiter) addVisitor(identifier string) *rate.Limiter {
|
|
// Calculate interval from rate: 1 second / requests per second
|
|
interval := time.Second / time.Duration(serverRateLimitPerSecond)
|
|
limiter := rate.NewLimiter(rate.Every(interval), serverRateLimitBurst)
|
|
|
|
rl.mtx.Lock()
|
|
rl.visitors[identifier] = &visitor{
|
|
limiter: limiter,
|
|
lastSeen: time.Now()}
|
|
rl.mtx.Unlock()
|
|
|
|
return limiter
|
|
}
|
|
|
|
// getVisitor returns a limiter for a visitor with the given identifier. It
|
|
// adds the visitor to the map if not seen before.
|
|
func (rl *RateLimiter) getVisitor(identifier string) *rate.Limiter {
|
|
rl.mtx.RLock()
|
|
v, exists := rl.visitors[identifier]
|
|
|
|
if !exists {
|
|
rl.mtx.RUnlock()
|
|
return rl.addVisitor(identifier)
|
|
}
|
|
|
|
v.lastSeen = time.Now()
|
|
rl.mtx.RUnlock()
|
|
|
|
return v.limiter
|
|
}
|
|
|
|
// cleanupVisitors deletes visitors that has not been seen in a while from the
|
|
// map of visitors
|
|
func (rl *RateLimiter) cleanupVisitors() {
|
|
for {
|
|
time.Sleep(time.Minute)
|
|
rl.mtx.Lock()
|
|
|
|
for identifier, v := range rl.visitors {
|
|
if time.Since(v.lastSeen) > 3*time.Minute {
|
|
delete(rl.visitors, identifier)
|
|
}
|
|
}
|
|
|
|
rl.mtx.Unlock()
|
|
}
|
|
}
|
|
|
|
// lookupIP returns the client IP address of the request, used as the rate
// limiting key.
//
// Sources are consulted in this order, skipping any that are empty:
//  1. the first entry of the X-Forwarded-For header, whitespace-trimmed;
//  2. the X-Real-IP header, whitespace-trimmed;
//  3. the host part of r.RemoteAddr.
//
// The port is stripped from RemoteAddr because it differs per TCP connection;
// keeping it would give every new connection from the same address its own
// limiter. If RemoteAddr cannot be split into host and port, it is returned
// unchanged.
//
// NOTE(review): both headers are supplied by the client unless a trusted
// proxy overwrites them — confirm the deployment always sits behind one.
func lookupIP(r *http.Request) string {
	if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" {
		// The header is a comma-separated list; only the first entry is used.
		first := strings.TrimSpace(strings.SplitN(forwardedFor, ",", 2)[0])
		if first != "" {
			return first
		}
	}

	if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); realIP != "" {
		return realIP
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No port present (or malformed): fall back to the raw value.
		return r.RemoteAddr
	}

	return host
}
|
|
|
|
// Limit is a middleware to rate limit the handler
|
|
func (rl *RateLimiter) Limit(next http.Handler) http.HandlerFunc {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
identifier := lookupIP(r)
|
|
limiter := rl.getVisitor(identifier)
|
|
|
|
if !limiter.Allow() {
|
|
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
|
log.WithFields(log.Fields{
|
|
"ip": identifier,
|
|
}).Warn("Too many requests")
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// ApplyLimit applies rate limit conditionally using the global limiter
|
|
func ApplyLimit(h http.HandlerFunc, rateLimit bool) http.Handler {
|
|
ret := h
|
|
|
|
if rateLimit {
|
|
ret = defaultLimiter.Limit(ret)
|
|
}
|
|
|
|
return ret
|
|
}
|