dnote/pkg/server/middleware/limit.go
2022-04-17 10:47:43 +10:00

136 lines
3 KiB
Go

/* Copyright (C) 2019, 2020, 2021 Monomax Software Pty Ltd
*
* This file is part of Dnote.
*
* Dnote is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Dnote is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with Dnote. If not, see <https://www.gnu.org/licenses/>.
*/
package middleware
import (
	"net"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/dnote/dnote/pkg/server/log"
	"golang.org/x/time/rate"
)
// visitor pairs the rate limiter assigned to a single client with the time
// that client was most recently looked up.
type visitor struct {
	// limiter is consulted (via Allow) for each request from the client.
	limiter *rate.Limiter

	// lastSeen is refreshed whenever the visitor is looked up and is compared
	// against the clock by the idle-visitor sweeper.
	lastSeen time.Time
}
// Process-wide registry of rate-limited clients, keyed by the identifier the
// Limit middleware derives for each request, together with the lock that
// guards it. Any write to the map or to an entry's fields (including
// lastSeen) requires the exclusive lock.
var (
	visitors = make(map[string]*visitor)
	mtx      sync.RWMutex
)

// init starts the background sweeper that evicts idle visitors from the
// registry. The sweeper loops forever, so the goroutine lives as long as the
// process.
func init() {
	go cleanupVisitors()
}
// addVisitor registers a visitor for identifier and returns its rate limiter.
//
// The lookup in getVisitor and this insertion are not covered by a single
// lock, so another goroutine may have registered the same identifier in the
// meantime. In that case the existing visitor is reused (and its lastSeen
// refreshed) rather than overwritten. Otherwise concurrent first requests
// from one client would each receive a fresh, full bucket.
func addVisitor(identifier string) *rate.Limiter {
	mtx.Lock()
	defer mtx.Unlock()

	now := time.Now()
	if v, ok := visitors[identifier]; ok {
		v.lastSeen = now
		return v.limiter
	}

	// Token bucket refilled at one token per second, allowing bursts of 60.
	limiter := rate.NewLimiter(rate.Every(time.Second), 60)
	visitors[identifier] = &visitor{
		limiter:  limiter,
		lastSeen: now,
	}
	return limiter
}
// getVisitor returns the rate limiter for the visitor with the given
// identifier, registering a new visitor through addVisitor if the identifier
// has not been seen before.
//
// Refreshing lastSeen is a write. It must therefore happen under the
// exclusive lock: with only a read lock, concurrent requests from the same
// client would race on the field, and so would the sweeper that reads it.
func getVisitor(identifier string) *rate.Limiter {
	mtx.Lock()
	if v, ok := visitors[identifier]; ok {
		v.lastSeen = time.Now()
		limiter := v.limiter
		mtx.Unlock()
		return limiter
	}
	mtx.Unlock()

	// addVisitor takes the lock itself, so it must not be called while the
	// lock is held.
	return addVisitor(identifier)
}
// cleanupVisitors periodically deletes visitors that have not been seen for a
// while from the map of visitors. Once a minute it removes every visitor
// whose lastSeen is more than three minutes in the past. It never returns and
// is meant to run in its own goroutine.
func cleanupVisitors() {
	const (
		sweepInterval = time.Minute
		maxIdle       = 3 * time.Minute
	)

	ticker := time.NewTicker(sweepInterval)
	defer ticker.Stop()

	for range ticker.C {
		// Read the clock once per sweep instead of once per entry. A visitor
		// is stale when now-lastSeen > maxIdle, i.e. lastSeen < now-maxIdle.
		cutoff := time.Now().Add(-maxIdle)

		mtx.Lock()
		for identifier, v := range visitors {
			if v.lastSeen.Before(cutoff) {
				// Deleting during range is permitted for Go maps.
				delete(visitors, identifier)
			}
		}
		mtx.Unlock()
	}
}
// lookupIP returns the client IP address used to identify a request for rate
// limiting.
//
// It checks three sources in order and uses the first non-empty one:
//  1. the first entry of the X-Forwarded-For header;
//  2. the X-Real-IP header;
//  3. the host part of r.RemoteAddr, or RemoteAddr unchanged if it has no
//     port.
//
// NOTE(review): both headers are client-controlled unless a trusted reverse
// proxy overwrites them. A client reaching the server directly can therefore
// pick its own identifier and evade the limit. Confirm the deployment always
// sits behind such a proxy.
func lookupIP(r *http.Request) string {
	if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" {
		// The header is a comma-separated list ("client, proxy1, proxy2");
		// the left-most entry is the originating client.
		first := forwardedFor
		if i := strings.IndexByte(forwardedFor, ','); i >= 0 {
			first = forwardedFor[:i]
		}
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); realIP != "" {
		return realIP
	}

	// RemoteAddr is normally "host:port". The port differs for every new
	// connection, so keying on the full address would give each connection
	// its own limiter and make the limit trivially bypassable.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
// Limit is a middleware that rate limits next per client. The client is
// identified by lookupIP and its limiter is obtained from getVisitor.
//
// If the limiter permits the request, it is forwarded to next. Otherwise the
// client receives 429 Too Many Requests and a warning carrying the client's
// identifier is logged.
func Limit(next http.Handler) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ip := lookupIP(r)

		if getVisitor(ip).Allow() {
			next.ServeHTTP(w, r)
			return
		}

		http.Error(w, "Too many requests", http.StatusTooManyRequests)
		log.WithFields(log.Fields{"ip": ip}).Warn("Too many requests")
	}
}
// ApplyLimit returns h wrapped in the Limit middleware when rateLimit is true.
// It returns h unwrapped when rateLimit is false or when the GO_ENV
// environment variable equals "TEST" (presumably so test runs are not
// throttled). GO_ENV is only consulted when rateLimit is true.
func ApplyLimit(h http.HandlerFunc, rateLimit bool) http.Handler {
	if !rateLimit || os.Getenv("GO_ENV") == "TEST" {
		return h
	}
	return Limit(h)
}