Add throttler class.

This commit is contained in:
Joachim Bauch 2024-05-14 11:58:01 +02:00
parent 5f18913646
commit 31b8c74d1c
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
2 changed files with 559 additions and 0 deletions

282
throttle.go Normal file
View file

@ -0,0 +1,282 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2024 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"context"
"errors"
"log"
"sync"
"time"
)
const (
// By default, if more than 10 requests failed in 30 minutes, a bruteforce
// attack is detected and the client will be blocked.
// maxBruteforceAttempts specifies the number of failed requests that may
// happen during "maxBruteforceDurationThreshold" until it is seen as
// "bruteforce" attempt.
maxBruteforceAttempts = 10
// maxBruteforceDurationThreshold specifies the duration during which the number of
// failed requests may not exceed "maxBruteforceAttempts" to be seen as
// "bruteforce" attempt.
maxBruteforceDurationThreshold = 30 * time.Minute
// maxBruteforceAge specifies the age for which failed attempts are remembered.
maxBruteforceAge = 12 * time.Hour
// maxThrottleDelay specifies the maxium time to sleep for failed requests.
maxThrottleDelay = 25 * time.Second
)
var (
ErrBruteforceDetected = errors.New("bruteforce detected")
)
type ThrottleFunc func(ctx context.Context)
type Throttler interface {
Close()
CheckBruteforce(ctx context.Context, client string, action string) (ThrottleFunc, error)
}
type throttleEntry struct {
ts time.Time
}
type memoryThrottler struct {
getNow func() time.Time
doDelay func(context.Context, time.Duration)
mu sync.RWMutex
clients map[string]map[string][]throttleEntry
closer *Closer
}
func NewMemoryThrottler() (Throttler, error) {
result := &memoryThrottler{
getNow: time.Now,
clients: make(map[string]map[string][]throttleEntry),
closer: NewCloser(),
}
result.doDelay = result.delay
go result.housekeeping()
return result, nil
}
func intPow(n, m int) int {
if m == 0 {
return 1
}
result := n
for i := 2; i <= m; i++ {
result *= n
}
return result
}
func (t *memoryThrottler) getEntries(client string, action string) []throttleEntry {
t.mu.RLock()
defer t.mu.RUnlock()
actions := t.clients[client]
if len(actions) == 0 {
return nil
}
entries := actions[action]
return entries
}
func (t *memoryThrottler) setEntries(client string, action string, entries []throttleEntry) {
t.mu.Lock()
defer t.mu.Unlock()
actions := t.clients[client]
if len(actions) == 0 {
if len(entries) == 0 {
return
}
actions = make(map[string][]throttleEntry)
t.clients[client] = actions
}
if len(entries) > 0 {
actions[action] = entries
} else {
delete(actions, action)
if len(actions) == 0 {
delete(t.clients, client)
}
}
}
func (t *memoryThrottler) addEntry(client string, action string, entry throttleEntry) int {
t.mu.Lock()
defer t.mu.Unlock()
actions, found := t.clients[client]
if !found {
t.clients[client] = map[string][]throttleEntry{
action: {
entry,
},
}
return 1
}
entries, found := actions[action]
if !found {
actions[action] = []throttleEntry{
entry,
}
return 1
}
actions[action] = append(entries, entry)
return len(entries) + 1
}
func (t *memoryThrottler) housekeeping() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for !t.closer.IsClosed() {
select {
case now := <-ticker.C:
t.cleanup(now)
case <-t.closer.C:
}
}
}
func (t *memoryThrottler) filterEntries(entries []throttleEntry, now time.Time) []throttleEntry {
start := 0
l := len(entries)
delta := now.Sub(entries[start].ts)
for delta > maxBruteforceAge {
start++
if start == l {
break
}
delta = now.Sub(entries[start].ts)
}
if start == l {
// No entries remaining, client is unknown.
return nil
}
if start > 0 {
entries = append([]throttleEntry{}, entries[start:]...)
}
return entries
}
func (t *memoryThrottler) cleanup(now time.Time) {
t.mu.Lock()
defer t.mu.Unlock()
for client, actions := range t.clients {
for action, entries := range actions {
newEntries := t.filterEntries(entries, now)
if newl := len(newEntries); newl == 0 {
delete(actions, action)
} else if newl != len(entries) {
actions[action] = newEntries
}
}
if len(actions) == 0 {
delete(t.clients, client)
}
}
}
func (t *memoryThrottler) Close() {
t.closer.Close()
}
func (t *memoryThrottler) getDelay(count int) time.Duration {
delay := time.Duration(100*intPow(2, count)) * time.Millisecond
if delay > maxThrottleDelay {
delay = maxThrottleDelay
}
return delay
}
func (t *memoryThrottler) CheckBruteforce(ctx context.Context, client string, action string) (ThrottleFunc, error) {
now := t.getNow()
doThrottle := func(ctx context.Context) {
t.throttle(ctx, client, action, now)
}
entries := t.getEntries(client, action)
l := len(entries)
if l == 0 {
return doThrottle, nil
}
if l >= maxBruteforceAttempts {
delta := now.Sub(entries[l-maxBruteforceAttempts].ts)
if delta <= maxBruteforceDurationThreshold {
log.Printf("Detected bruteforce attempt on \"%s\" from %s", action, client)
return doThrottle, ErrBruteforceDetected
}
}
// Remove old entries.
newEntries := t.filterEntries(entries, now)
if newl := len(newEntries); newl == 0 {
t.setEntries(client, action, nil)
return doThrottle, nil
} else if newl != l {
t.setEntries(client, action, newEntries)
}
return doThrottle, nil
}
func (t *memoryThrottler) throttle(ctx context.Context, client string, action string, now time.Time) {
entry := throttleEntry{
ts: now,
}
count := t.addEntry(client, action, entry)
delay := t.getDelay(count - 1)
log.Printf("Failed attempt on \"%s\" from %s, throttling by %s", action, client, delay)
t.doDelay(ctx, delay)
}
func (t *memoryThrottler) delay(ctx context.Context, duration time.Duration) {
c, cancel := context.WithTimeout(ctx, duration)
defer cancel()
<-c.Done()
}

277
throttle_test.go Normal file
View file

@ -0,0 +1,277 @@
/**
* Standalone signaling server for the Nextcloud Spreed app.
* Copyright (C) 2024 struktur AG
*
* @author Joachim Bauch <bauch@struktur.de>
*
* @license GNU AGPL version 3 or any later version
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
import (
"context"
"testing"
"time"
)
func newMemoryThrottlerForTest(t *testing.T) *memoryThrottler {
t.Helper()
result, err := NewMemoryThrottler()
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
result.Close()
})
return result.(*memoryThrottler)
}
type throttlerTiming struct {
t *testing.T
now time.Time
expectedSleep time.Duration
}
func (t *throttlerTiming) getNow() time.Time {
return t.now
}
func (t *throttlerTiming) doDelay(ctx context.Context, duration time.Duration) {
t.t.Helper()
if duration != t.expectedSleep {
t.t.Errorf("expected sleep %s, got %s", t.expectedSleep, duration)
}
}
func TestThrottler(t *testing.T) {
timing := &throttlerTiming{
t: t,
now: time.Now(),
}
th := newMemoryThrottlerForTest(t)
th.getNow = timing.getNow
th.doDelay = timing.doDelay
ctx := context.Background()
throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
timing.expectedSleep = 100 * time.Millisecond
throttle1(ctx)
timing.now = timing.now.Add(time.Millisecond)
throttle2, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
timing.expectedSleep = 200 * time.Millisecond
throttle2(ctx)
timing.now = timing.now.Add(time.Millisecond)
throttle3, err := th.CheckBruteforce(ctx, "192.168.0.2", "action1")
if err != nil {
t.Error(err)
}
timing.expectedSleep = 100 * time.Millisecond
throttle3(ctx)
timing.now = timing.now.Add(time.Millisecond)
throttle4, err := th.CheckBruteforce(ctx, "192.168.0.1", "action2")
if err != nil {
t.Error(err)
}
timing.expectedSleep = 100 * time.Millisecond
throttle4(ctx)
}
func TestThrottler_Bruteforce(t *testing.T) {
timing := &throttlerTiming{
t: t,
now: time.Now(),
}
th := newMemoryThrottlerForTest(t)
th.getNow = timing.getNow
th.doDelay = timing.doDelay
ctx := context.Background()
for i := 0; i < maxBruteforceAttempts; i++ {
timing.now = timing.now.Add(time.Millisecond)
throttle, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
if i == 0 {
timing.expectedSleep = 100 * time.Millisecond
} else {
timing.expectedSleep *= 2
if timing.expectedSleep > maxThrottleDelay {
timing.expectedSleep = maxThrottleDelay
}
}
throttle(ctx)
}
timing.now = timing.now.Add(time.Millisecond)
if _, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1"); err == nil {
t.Error("expected bruteforce error")
} else if err != ErrBruteforceDetected {
t.Errorf("expected error %s, got %s", ErrBruteforceDetected, err)
}
}
func TestThrottler_Cleanup(t *testing.T) {
timing := &throttlerTiming{
t: t,
now: time.Now(),
}
th := newMemoryThrottlerForTest(t)
th.getNow = timing.getNow
th.doDelay = timing.doDelay
ctx := context.Background()
throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
timing.expectedSleep = 100 * time.Millisecond
throttle1(ctx)
throttle2, err := th.CheckBruteforce(ctx, "192.168.0.2", "action1")
if err != nil {
t.Error(err)
}
timing.expectedSleep = 100 * time.Millisecond
throttle2(ctx)
timing.now = timing.now.Add(time.Hour)
throttle3, err := th.CheckBruteforce(ctx, "192.168.0.1", "action2")
if err != nil {
t.Error(err)
}
timing.expectedSleep = 100 * time.Millisecond
throttle3(ctx)
throttle4, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
timing.expectedSleep = 200 * time.Millisecond
throttle4(ctx)
timing.now = timing.now.Add(-time.Hour).Add(maxBruteforceAge).Add(time.Second)
th.cleanup(timing.now)
if entries := th.getEntries("192.168.0.1", "action1"); len(entries) != 1 {
t.Errorf("should have removed one entry, got %+v", entries)
}
if entries := th.getEntries("192.168.0.1", "action2"); len(entries) != 1 {
t.Errorf("should have kept entry, got %+v", entries)
}
th.mu.RLock()
if _, found := th.clients["192.168.0.2"]; found {
t.Error("should have removed client \"192.168.0.2\"")
}
th.mu.RUnlock()
throttle5, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
timing.expectedSleep = 200 * time.Millisecond
throttle5(ctx)
}
func TestThrottler_ExpirePartial(t *testing.T) {
timing := &throttlerTiming{
t: t,
now: time.Now(),
}
th := newMemoryThrottlerForTest(t)
th.getNow = timing.getNow
th.doDelay = timing.doDelay
ctx := context.Background()
throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
timing.expectedSleep = 100 * time.Millisecond
throttle1(ctx)
timing.now = timing.now.Add(time.Minute)
throttle2, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
timing.expectedSleep = 200 * time.Millisecond
throttle2(ctx)
timing.now = timing.now.Add(maxBruteforceAge).Add(-time.Minute + time.Second)
throttle3, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
timing.expectedSleep = 200 * time.Millisecond
throttle3(ctx)
}
func TestThrottler_ExpireAll(t *testing.T) {
timing := &throttlerTiming{
t: t,
now: time.Now(),
}
th := newMemoryThrottlerForTest(t)
th.getNow = timing.getNow
th.doDelay = timing.doDelay
ctx := context.Background()
throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
timing.expectedSleep = 100 * time.Millisecond
throttle1(ctx)
timing.now = timing.now.Add(time.Millisecond)
throttle2, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
timing.expectedSleep = 200 * time.Millisecond
throttle2(ctx)
timing.now = timing.now.Add(maxBruteforceAge).Add(time.Second)
throttle3, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1")
if err != nil {
t.Error(err)
}
timing.expectedSleep = 100 * time.Millisecond
throttle3(ctx)
}