From 31b8c74d1cc4557f1d7c6c0e1d0fedb678c047e3 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 14 May 2024 11:58:01 +0200 Subject: [PATCH 1/5] Add throttler class. --- throttle.go | 282 +++++++++++++++++++++++++++++++++++++++++++++++ throttle_test.go | 277 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 559 insertions(+) create mode 100644 throttle.go create mode 100644 throttle_test.go diff --git a/throttle.go b/throttle.go new file mode 100644 index 0000000..af0abde --- /dev/null +++ b/throttle.go @@ -0,0 +1,282 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2024 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "errors" + "log" + "sync" + "time" +) + +const ( + // By default, if more than 10 requests failed in 30 minutes, a bruteforce + // attack is detected and the client will be blocked. + + // maxBruteforceAttempts specifies the number of failed requests that may + // happen during "maxBruteforceDurationThreshold" until it is seen as + // "bruteforce" attempt. + maxBruteforceAttempts = 10 + + // maxBruteforceDurationThreshold specifies the duration during which the number of + // failed requests may not exceed "maxBruteforceAttempts" to be seen as + // "bruteforce" attempt. + maxBruteforceDurationThreshold = 30 * time.Minute + + // maxBruteforceAge specifies the age for which failed attempts are remembered. + maxBruteforceAge = 12 * time.Hour + + // maxThrottleDelay specifies the maxium time to sleep for failed requests. + maxThrottleDelay = 25 * time.Second +) + +var ( + ErrBruteforceDetected = errors.New("bruteforce detected") +) + +type ThrottleFunc func(ctx context.Context) + +type Throttler interface { + Close() + + CheckBruteforce(ctx context.Context, client string, action string) (ThrottleFunc, error) +} + +type throttleEntry struct { + ts time.Time +} + +type memoryThrottler struct { + getNow func() time.Time + doDelay func(context.Context, time.Duration) + + mu sync.RWMutex + clients map[string]map[string][]throttleEntry + + closer *Closer +} + +func NewMemoryThrottler() (Throttler, error) { + result := &memoryThrottler{ + getNow: time.Now, + + clients: make(map[string]map[string][]throttleEntry), + + closer: NewCloser(), + } + result.doDelay = result.delay + go result.housekeeping() + return result, nil +} + +func intPow(n, m int) int { + if m == 0 { + return 1 + } + + result := n + for i := 2; i <= m; i++ { + result *= n + } + return result +} + +func (t *memoryThrottler) getEntries(client string, action string) []throttleEntry { + t.mu.RLock() + defer t.mu.RUnlock() + + actions := t.clients[client] + if len(actions) == 0 { + return nil + } + + entries := actions[action] + return entries +} + +func (t *memoryThrottler) setEntries(client string, action string, entries []throttleEntry) { + t.mu.Lock() + defer t.mu.Unlock() + + actions := t.clients[client] + if len(actions) == 0 { + if len(entries) == 0 { + return + } + + actions = make(map[string][]throttleEntry) + t.clients[client] = actions + } + + if len(entries) > 0 { + actions[action] = entries + } else { + delete(actions, action) + if len(actions) == 0 { + delete(t.clients, client) + } + } +} + +func (t *memoryThrottler) addEntry(client string, action string, entry throttleEntry) int { + t.mu.Lock() + defer t.mu.Unlock() + + actions, found := t.clients[client] + if !found { + t.clients[client] = map[string][]throttleEntry{ + action: { + entry, + }, + } + return 1 + } + + entries, found := actions[action] + if !found { + actions[action] = []throttleEntry{ + entry, + } + return 1 + } + + actions[action] = append(entries, entry) + return len(entries) + 1 +} + +func (t *memoryThrottler) housekeeping() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + + for !t.closer.IsClosed() { + select { + case now := <-ticker.C: + t.cleanup(now) + case <-t.closer.C: + } + } +} + +func (t *memoryThrottler) filterEntries(entries []throttleEntry, now time.Time) []throttleEntry { + start := 0 + l := len(entries) + delta := now.Sub(entries[start].ts) + for delta > maxBruteforceAge { + start++ + if start == l { + break + } + delta = now.Sub(entries[start].ts) + } + + if start == l { + // No entries remaining, client is unknown. + return nil + } + + if start > 0 { + entries = append([]throttleEntry{}, entries[start:]...) + } + return entries +} + +func (t *memoryThrottler) cleanup(now time.Time) { + t.mu.Lock() + defer t.mu.Unlock() + + for client, actions := range t.clients { + for action, entries := range actions { + newEntries := t.filterEntries(entries, now) + if newl := len(newEntries); newl == 0 { + delete(actions, action) + } else if newl != len(entries) { + actions[action] = newEntries + } + } + + if len(actions) == 0 { + delete(t.clients, client) + } + } +} + +func (t *memoryThrottler) Close() { + t.closer.Close() +} + +func (t *memoryThrottler) getDelay(count int) time.Duration { + delay := time.Duration(100*intPow(2, count)) * time.Millisecond + if delay > maxThrottleDelay { + delay = maxThrottleDelay + } + return delay +} + +func (t *memoryThrottler) CheckBruteforce(ctx context.Context, client string, action string) (ThrottleFunc, error) { + now := t.getNow() + doThrottle := func(ctx context.Context) { + t.throttle(ctx, client, action, now) + } + + entries := t.getEntries(client, action) + l := len(entries) + if l == 0 { + return doThrottle, nil + } + + if l >= maxBruteforceAttempts { + delta := now.Sub(entries[l-maxBruteforceAttempts].ts) + if delta <= maxBruteforceDurationThreshold { + log.Printf("Detected bruteforce attempt on \"%s\" from %s", action, client) + return doThrottle, ErrBruteforceDetected + } + } + + // Remove old entries. + newEntries := t.filterEntries(entries, now) + if newl := len(newEntries); newl == 0 { + t.setEntries(client, action, nil) + return doThrottle, nil + } else if newl != l { + t.setEntries(client, action, newEntries) + } + + return doThrottle, nil +} + +func (t *memoryThrottler) throttle(ctx context.Context, client string, action string, now time.Time) { + entry := throttleEntry{ + ts: now, + } + count := t.addEntry(client, action, entry) + delay := t.getDelay(count - 1) + log.Printf("Failed attempt on \"%s\" from %s, throttling by %s", action, client, delay) + t.doDelay(ctx, delay) +} + +func (t *memoryThrottler) delay(ctx context.Context, duration time.Duration) { + c, cancel := context.WithTimeout(ctx, duration) + defer cancel() + + <-c.Done() +} diff --git a/throttle_test.go b/throttle_test.go new file mode 100644 index 0000000..a60beb4 --- /dev/null +++ b/throttle_test.go @@ -0,0 +1,277 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2024 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "testing" + "time" +) + +func newMemoryThrottlerForTest(t *testing.T) *memoryThrottler { + t.Helper() + result, err := NewMemoryThrottler() + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + result.Close() + }) + + return result.(*memoryThrottler) +} + +type throttlerTiming struct { + t *testing.T + + now time.Time + expectedSleep time.Duration +} + +func (t *throttlerTiming) getNow() time.Time { + return t.now +} + +func (t *throttlerTiming) doDelay(ctx context.Context, duration time.Duration) { + t.t.Helper() + if duration != t.expectedSleep { + t.t.Errorf("expected sleep %s, got %s", t.expectedSleep, duration) + } +} + +func TestThrottler(t *testing.T) { + timing := &throttlerTiming{ + t: t, + now: time.Now(), + } + th := newMemoryThrottlerForTest(t) + th.getNow = timing.getNow + th.doDelay = timing.doDelay + + ctx := context.Background() + + throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") + if err != nil { + t.Error(err) + } + timing.expectedSleep = 100 * time.Millisecond + throttle1(ctx) + + timing.now = timing.now.Add(time.Millisecond) + throttle2, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") + if err != nil { + t.Error(err) + } + timing.expectedSleep = 200 * time.Millisecond + throttle2(ctx) + + timing.now = timing.now.Add(time.Millisecond) + throttle3, err := th.CheckBruteforce(ctx, "192.168.0.2", "action1") + if err != nil { + t.Error(err) + } + timing.expectedSleep = 100 * time.Millisecond + throttle3(ctx) + + timing.now = timing.now.Add(time.Millisecond) + throttle4, err := th.CheckBruteforce(ctx, "192.168.0.1", "action2") + if err != nil { + t.Error(err) + } + timing.expectedSleep = 100 * time.Millisecond + throttle4(ctx) +} + +func TestThrottler_Bruteforce(t *testing.T) { + timing := &throttlerTiming{ + t: t, + now: time.Now(), + } + th := newMemoryThrottlerForTest(t) + th.getNow = timing.getNow + th.doDelay = timing.doDelay + + ctx := context.Background() + + for i := 0; i < maxBruteforceAttempts; i++ { + timing.now = timing.now.Add(time.Millisecond) + throttle, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") + if err != nil { + t.Error(err) + } + if i == 0 { + timing.expectedSleep = 100 * time.Millisecond + } else { + timing.expectedSleep *= 2 + if timing.expectedSleep > maxThrottleDelay { + timing.expectedSleep = maxThrottleDelay + } + } + throttle(ctx) + } + + timing.now = timing.now.Add(time.Millisecond) + if _, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1"); err == nil { + t.Error("expected bruteforce error") + } else if err != ErrBruteforceDetected { + t.Errorf("expected error %s, got %s", ErrBruteforceDetected, err) + } +} + +func TestThrottler_Cleanup(t *testing.T) { + timing := &throttlerTiming{ + t: t, + now: time.Now(), + } + th := newMemoryThrottlerForTest(t) + th.getNow = timing.getNow + th.doDelay = timing.doDelay + + ctx := context.Background() + + throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") + if err != nil { + t.Error(err) + } + timing.expectedSleep = 100 * time.Millisecond + throttle1(ctx) + + throttle2, err := th.CheckBruteforce(ctx, "192.168.0.2", "action1") + if err != nil { + t.Error(err) + } + timing.expectedSleep = 100 * time.Millisecond + throttle2(ctx) + + timing.now = timing.now.Add(time.Hour) + throttle3, err := th.CheckBruteforce(ctx, "192.168.0.1", "action2") + if err != nil { + t.Error(err) + } + timing.expectedSleep = 100 * time.Millisecond + throttle3(ctx) + + throttle4, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") + if err != nil { + t.Error(err) + } + timing.expectedSleep = 200 * time.Millisecond + throttle4(ctx) + + timing.now = timing.now.Add(-time.Hour).Add(maxBruteforceAge).Add(time.Second) + th.cleanup(timing.now) + + if entries := th.getEntries("192.168.0.1", "action1"); len(entries) != 1 { + t.Errorf("should have removed one entry, got %+v", entries) + } + if entries := th.getEntries("192.168.0.1", "action2"); len(entries) != 1 { + t.Errorf("should have kept entry, got %+v", entries) + } + + th.mu.RLock() + if _, found := th.clients["192.168.0.2"]; found { + t.Error("should have removed client \"192.168.0.2\"") + } + th.mu.RUnlock() + + throttle5, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") + if err != nil { + t.Error(err) + } + timing.expectedSleep = 200 * time.Millisecond + throttle5(ctx) +} + +func TestThrottler_ExpirePartial(t *testing.T) { + timing := &throttlerTiming{ + t: t, + now: time.Now(), + } + th := newMemoryThrottlerForTest(t) + th.getNow = timing.getNow + th.doDelay = timing.doDelay + + ctx := context.Background() + + throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") + if err != nil { + t.Error(err) + } + timing.expectedSleep = 100 * time.Millisecond + throttle1(ctx) + + timing.now = timing.now.Add(time.Minute) + + throttle2, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") + if err != nil { + t.Error(err) + } + timing.expectedSleep = 200 * time.Millisecond + throttle2(ctx) + + timing.now = timing.now.Add(maxBruteforceAge).Add(-time.Minute + time.Second) + + throttle3, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") + if err != nil { + t.Error(err) + } + timing.expectedSleep = 200 * time.Millisecond + throttle3(ctx) +} + +func TestThrottler_ExpireAll(t *testing.T) { + timing := &throttlerTiming{ + t: t, + now: time.Now(), + } + th := newMemoryThrottlerForTest(t) + th.getNow = timing.getNow + th.doDelay = timing.doDelay + + ctx := context.Background() + + throttle1, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") + if err != nil { + t.Error(err) + } + timing.expectedSleep = 100 * time.Millisecond + throttle1(ctx) + + timing.now = timing.now.Add(time.Millisecond) + + throttle2, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") + if err != nil { + t.Error(err) + } + timing.expectedSleep = 200 * time.Millisecond + throttle2(ctx) + + timing.now = timing.now.Add(maxBruteforceAge).Add(time.Second) + + throttle3, err := th.CheckBruteforce(ctx, "192.168.0.1", "action1") + if err != nil { + t.Error(err) + } + timing.expectedSleep = 100 * time.Millisecond + throttle3(ctx) +} From 7f8e44b3b52eefc7cd60bd9ac953a9642ab52e1e Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 14 May 2024 12:02:36 +0200 Subject: [PATCH 2/5] Add bruteforce detection to backend server room handler. --- backend_server.go | 14 +++++++++++++- hub.go | 10 ++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/backend_server.go b/backend_server.go index a534ed5..309c21a 100644 --- a/backend_server.go +++ b/backend_server.go @@ -761,6 +761,16 @@ func (b *BackendServer) startDialout(roomid string, backend *Backend, backendUrl } func (b *BackendServer) roomHandler(w http.ResponseWriter, r *http.Request, body []byte) { + throttle, err := b.hub.throttler.CheckBruteforce(r.Context(), b.hub.getRealUserIP(r), "BackendRoomAuth") + if err == ErrBruteforceDetected { + http.Error(w, "Too many requests", http.StatusTooManyRequests) + return + } else if err != nil { + log.Printf("Error checking for bruteforce: %s", err) + http.Error(w, "Could not check for bruteforce", http.StatusInternalServerError) + return + } + v := mux.Vars(r) roomid := v["roomid"] @@ -773,6 +783,7 @@ func (b *BackendServer) roomHandler(w http.ResponseWriter, r *http.Request, body if backend == nil { // Unknown backend URL passed, return immediately. + throttle(r.Context()) http.Error(w, "Authentication check failed", http.StatusForbidden) return } @@ -794,12 +805,14 @@ func (b *BackendServer) roomHandler(w http.ResponseWriter, r *http.Request, body } if backend == nil { + throttle(r.Context()) http.Error(w, "Authentication check failed", http.StatusForbidden) return } } if !ValidateBackendChecksum(r, body, backend.Secret()) { + throttle(r.Context()) http.Error(w, "Authentication check failed", http.StatusForbidden) return } @@ -814,7 +827,6 @@ func (b *BackendServer) roomHandler(w http.ResponseWriter, r *http.Request, body request.ReceivedTime = time.Now().UnixNano() var response any - var err error switch request.Type { case "invite": b.sendRoomInvite(roomid, backend, request.Invite.UserIds, request.Invite.Properties) diff --git a/hub.go b/hub.go index 77a86c7..fd80ffc 100644 --- a/hub.go +++ b/hub.go @@ -173,6 +173,8 @@ type Hub struct { rpcServer *GrpcServer rpcClients *GrpcClients + + throttler Throttler } func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer, rpcClients *GrpcClients, etcdClient *EtcdClient, r *mux.Router, version string) (*Hub, error) { @@ -328,6 +330,11 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer } } + throttler, err := NewMemoryThrottler() + if err != nil { + return nil, err + } + hub := &Hub{ events: events, upgrader: websocket.Upgrader{ @@ -376,6 +383,8 @@ func NewHub(config *goconf.ConfigFile, events AsyncEvents, rpcServer *GrpcServer rpcServer: rpcServer, rpcClients: rpcClients, + + throttler: throttler, } hub.setWelcomeMessage(&ServerMessage{ Type: "welcome", @@ -498,6 +507,7 @@ loop: func (h *Hub) Stop() { h.closer.Close() + h.throttler.Close() } func (h *Hub) Reload(config *goconf.ConfigFile) { From 39f4b2eb112b8b885f3dd11e2b22ad469413f6fe Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 14 May 2024 12:03:40 +0200 Subject: [PATCH 3/5] server: Increase default write timeout so delayed responses can be sent out. --- server.conf.in | 4 ++-- server/main.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/server.conf.in b/server.conf.in index 43289b1..6b61b7b 100644 --- a/server.conf.in +++ b/server.conf.in @@ -7,7 +7,7 @@ #readtimeout = 15 # HTTP socket write timeout in seconds. -#writetimeout = 15 +#writetimeout = 30 [https] # IP and port to listen on for HTTPS requests. @@ -18,7 +18,7 @@ #readtimeout = 15 # HTTPS socket write timeout in seconds. -#writetimeout = 15 +#writetimeout = 30 # Certificate / private key to use for the HTTPS server. certificate = /etc/nginx/ssl/server.crt diff --git a/server/main.go b/server/main.go index a31a0f5..2a5dc60 100644 --- a/server/main.go +++ b/server/main.go @@ -61,7 +61,7 @@ var ( const ( defaultReadTimeout = 15 - defaultWriteTimeout = 15 + defaultWriteTimeout = 30 initialMcuRetry = time.Second maxMcuRetry = time.Second * 16 From e862392872555a851f124450eeb269bd7f58b8ed Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 14 May 2024 13:40:19 +0200 Subject: [PATCH 4/5] Add throttled requests to metrics. --- docs/prometheus-metrics.md | 2 ++ throttle.go | 7 +++++ throttle_stats_prometheus.go | 51 ++++++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+) create mode 100644 throttle_stats_prometheus.go diff --git a/docs/prometheus-metrics.md b/docs/prometheus-metrics.md index 8e54db5..5655f5d 100644 --- a/docs/prometheus-metrics.md +++ b/docs/prometheus-metrics.md @@ -49,3 +49,5 @@ The following metrics are available: | `signaling_grpc_client_calls_total` | Counter | 1.0.0 | The total number of GRPC client calls | `method` | | `signaling_grpc_server_calls_total` | Counter | 1.0.0 | The total number of GRPC server calls | `method` | | `signaling_http_client_pool_connections` | Gauge | 1.2.4 | The current number of HTTP client connections per host | `host` | +| `signaling_throttle_delayed_total` | Counter | 1.2.5 | The total number of delayed requests | `action`, `delay` | +| `signaling_throttle_bruteforce_total` | Counter | 1.2.5 | The total number of rejected bruteforce requests | `action` | diff --git a/throttle.go b/throttle.go index af0abde..43332be 100644 --- a/throttle.go +++ b/throttle.go @@ -25,6 +25,7 @@ import ( "context" "errors" "log" + "strconv" "sync" "time" ) @@ -54,6 +55,10 @@ var ( ErrBruteforceDetected = errors.New("bruteforce detected") ) +func init() { + RegisterThrottleStats() +} + type ThrottleFunc func(ctx context.Context) type Throttler interface { @@ -248,6 +253,7 @@ func (t *memoryThrottler) CheckBruteforce(ctx context.Context, client string, ac delta := now.Sub(entries[l-maxBruteforceAttempts].ts) if delta <= maxBruteforceDurationThreshold { log.Printf("Detected bruteforce attempt on \"%s\" from %s", action, client) + statsThrottleBruteforceTotal.WithLabelValues(action).Inc() return doThrottle, ErrBruteforceDetected } } @@ -271,6 +277,7 @@ func (t *memoryThrottler) throttle(ctx context.Context, client string, action st count := t.addEntry(client, action, entry) delay := t.getDelay(count - 1) log.Printf("Failed attempt on \"%s\" from %s, throttling by %s", action, client, delay) + statsThrottleDelayedTotal.WithLabelValues(action, strconv.FormatInt(delay.Milliseconds(), 10)).Inc() t.doDelay(ctx, delay) } diff --git a/throttle_stats_prometheus.go b/throttle_stats_prometheus.go new file mode 100644 index 0000000..8279fe0 --- /dev/null +++ b/throttle_stats_prometheus.go @@ -0,0 +1,51 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2024 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "github.com/prometheus/client_golang/prometheus" +) + +var ( + statsThrottleDelayedTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "signaling", + Subsystem: "throttle", + Name: "delayed_total", + Help: "The total number of delayed requests", + }, []string{"action", "delay"}) + + statsThrottleBruteforceTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "signaling", + Subsystem: "throttle", + Name: "bruteforce_total", + Help: "The total number of rejected bruteforce requests", + }, []string{"action"}) + + throttleStats = []prometheus.Collector{ + statsThrottleDelayedTotal, + statsThrottleBruteforceTotal, + } +) + +func RegisterThrottleStats() { + registerAll(throttleStats...) +} From 4c807c86e83dbcb2958b9d4da0ad75d0d9417588 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 14 May 2024 14:25:08 +0200 Subject: [PATCH 5/5] Throttle resume / internal hello. --- hub.go | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/hub.go b/hub.go index fd80ffc..08ccac6 100644 --- a/hub.go +++ b/hub.go @@ -63,6 +63,7 @@ var ( NoSuchSession = NewError("no_such_session", "The session to resume does not exist.") TokenNotValidYet = NewError("token_not_valid_yet", "The token is not valid yet.") TokenExpired = NewError("token_expired", "The token is expired.") + TooManyRequests = NewError("too_many_requests", "Too many requests.") // Maximum number of concurrent requests to a backend. defaultMaxConcurrentRequestsPerHost = 8 @@ -1134,8 +1135,19 @@ func (h *Hub) tryProxyResume(c HandlerClient, resumeId string, message *ClientMe } func (h *Hub) processHello(client HandlerClient, message *ClientMessage) { + ctx := context.TODO() resumeId := message.Hello.ResumeId if resumeId != "" { + throttle, err := h.throttler.CheckBruteforce(ctx, client.RemoteAddr(), "HelloResume") + if err == ErrBruteforceDetected { + client.SendMessage(message.NewErrorServerMessage(TooManyRequests)) + return + } else if err != nil { + log.Printf("Error checking for bruteforce: %s", err) + client.SendMessage(message.NewWrappedErrorServerMessage(err)) + return + } + data := h.decodeSessionId(resumeId, privateSessionName) if data == nil { statsHubSessionResumeFailed.Inc() @@ -1143,6 +1155,7 @@ func (h *Hub) processHello(client HandlerClient, message *ClientMessage) { return } + throttle(ctx) client.SendMessage(message.NewErrorServerMessage(NoSuchSession)) return } @@ -1156,6 +1169,7 @@ func (h *Hub) processHello(client HandlerClient, message *ClientMessage) { return } + throttle(ctx) client.SendMessage(message.NewErrorServerMessage(NoSuchSession)) return } @@ -1376,18 +1390,31 @@ func (h *Hub) processHelloInternal(client HandlerClient, message *ClientMessage) return } + ctx := context.TODO() + throttle, err := h.throttler.CheckBruteforce(ctx, client.RemoteAddr(), "HelloInternal") + if err == ErrBruteforceDetected { + client.SendMessage(message.NewErrorServerMessage(TooManyRequests)) + return + } else if err != nil { + log.Printf("Error checking for bruteforce: %s", err) + client.SendMessage(message.NewWrappedErrorServerMessage(err)) + return + } + // Validate internal connection. rnd := message.Hello.Auth.internalParams.Random mac := hmac.New(sha256.New, h.internalClientsSecret) mac.Write([]byte(rnd)) // nolint check := hex.EncodeToString(mac.Sum(nil)) if len(rnd) < minTokenRandomLength || check != message.Hello.Auth.internalParams.Token { + throttle(ctx) client.SendMessage(message.NewErrorServerMessage(InvalidToken)) return } backend := h.backend.GetBackend(message.Hello.Auth.internalParams.parsedBackend) if backend == nil { + throttle(ctx) client.SendMessage(message.NewErrorServerMessage(InvalidBackendUrl)) return }