Move throttler code to async package.

This commit is contained in:
Joachim Bauch 2025-12-10 16:05:14 +01:00
commit 1c3a03e972
No known key found for this signature in database
GPG key ID: 77C1D22D53E15F02
6 changed files with 44 additions and 36 deletions

View file

@ -19,7 +19,7 @@
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
package async
import (
"context"
@ -92,9 +92,12 @@ type throttleEntry struct {
ts time.Time
}
type GetTimeFunc func() time.Time
type ThrottleDelayFunc func(context.Context, time.Duration)
type memoryThrottler struct {
getNow func() time.Time
doDelay func(context.Context, time.Duration)
getNow GetTimeFunc
doDelay ThrottleDelayFunc
mu sync.RWMutex
// +checklocks:mu
@ -104,14 +107,18 @@ type memoryThrottler struct {
}
func NewMemoryThrottler() (Throttler, error) {
return NewCustomMemoryThrottler(time.Now, defaultDelay)
}
func NewCustomMemoryThrottler(getNow GetTimeFunc, delay ThrottleDelayFunc) (Throttler, error) {
result := &memoryThrottler{
getNow: time.Now,
getNow: getNow,
doDelay: delay,
clients: make(map[string]map[string][]throttleEntry),
closer: internal.NewCloser(),
}
result.doDelay = result.delay
go result.housekeeping()
return result, nil
}
@ -310,7 +317,7 @@ func (t *memoryThrottler) throttle(ctx context.Context, client string, action st
t.doDelay(ctx, delay)
}
func (t *memoryThrottler) delay(ctx context.Context, duration time.Duration) {
func defaultDelay(ctx context.Context, duration time.Duration) {
c, cancel := context.WithTimeout(ctx, duration)
defer cancel()

View file

@ -19,7 +19,7 @@
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
package async
import (
"github.com/prometheus/client_golang/prometheus"

View file

@ -19,10 +19,9 @@
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package signaling
package async
import (
"context"
"testing"
"time"
@ -45,22 +44,6 @@ func newMemoryThrottlerForTest(t *testing.T) Throttler {
return result
}
type throttlerTiming struct {
t *testing.T
now time.Time
expectedSleep time.Duration
}
func (t *throttlerTiming) getNow() time.Time {
return t.now
}
func (t *throttlerTiming) doDelay(ctx context.Context, duration time.Duration) {
t.t.Helper()
assert.Equal(t.t, t.expectedSleep, duration)
}
func expectDelay(t *testing.T, f func(), delay time.Duration) {
t.Helper()
a := time.Now()

View file

@ -49,6 +49,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/strukturag/nextcloud-spreed-signaling/api"
"github.com/strukturag/nextcloud-spreed-signaling/async"
"github.com/strukturag/nextcloud-spreed-signaling/log"
)
@ -840,7 +841,7 @@ func (b *BackendServer) startDialout(ctx context.Context, roomid string, backend
func (b *BackendServer) roomHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, body []byte) {
throttle, err := b.hub.throttler.CheckBruteforce(ctx, b.hub.getRealUserIP(r), "BackendRoomAuth")
if err == ErrBruteforceDetected {
if err == async.ErrBruteforceDetected {
http.Error(w, "Too many requests", http.StatusTooManyRequests)
return
} else if err != nil {

8
hub.go
View file

@ -213,7 +213,7 @@ type Hub struct {
rpcServer *GrpcServer
rpcClients *GrpcClients
throttler Throttler
throttler async.Throttler
skipFederationVerify bool
federationTimeout time.Duration
@ -352,7 +352,7 @@ func NewHub(ctx context.Context, config *goconf.ConfigFile, events AsyncEvents,
return nil, err
}
throttler, err := NewMemoryThrottler()
throttler, err := async.NewMemoryThrottler()
if err != nil {
return nil, err
}
@ -1277,7 +1277,7 @@ func (h *Hub) processHello(client HandlerClient, message *ClientMessage) {
resumeId := message.Hello.ResumeId
if resumeId != "" {
throttle, err := h.throttler.CheckBruteforce(ctx, client.RemoteAddr(), "HelloResume")
if err == ErrBruteforceDetected {
if err == async.ErrBruteforceDetected {
client.SendMessage(message.NewErrorServerMessage(TooManyRequests))
return
} else if err != nil {
@ -1581,7 +1581,7 @@ func (h *Hub) processHelloInternal(client HandlerClient, message *ClientMessage)
ctx := log.NewLoggerContext(client.Context(), h.logger)
throttle, err := h.throttler.CheckBruteforce(ctx, client.RemoteAddr(), "HelloInternal")
if err == ErrBruteforceDetected {
if err == async.ErrBruteforceDetected {
client.SendMessage(message.NewErrorServerMessage(TooManyRequests))
return
} else if err != nil {

View file

@ -53,6 +53,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/strukturag/nextcloud-spreed-signaling/api"
"github.com/strukturag/nextcloud-spreed-signaling/async"
"github.com/strukturag/nextcloud-spreed-signaling/internal"
"github.com/strukturag/nextcloud-spreed-signaling/log"
"github.com/strukturag/nextcloud-spreed-signaling/test"
@ -1440,6 +1441,22 @@ func TestClientHelloResume(t *testing.T) {
}
}
type throttlerTiming struct {
t *testing.T
now time.Time
expectedSleep time.Duration
}
func (t *throttlerTiming) getNow() time.Time {
return t.now
}
func (t *throttlerTiming) doDelay(ctx context.Context, duration time.Duration) {
t.t.Helper()
assert.Equal(t.t, t.expectedSleep, duration)
}
func TestClientHelloResumeThrottle(t *testing.T) {
t.Parallel()
require := require.New(t)
@ -1450,12 +1467,12 @@ func TestClientHelloResumeThrottle(t *testing.T) {
t: t,
now: time.Now(),
}
throttler := newMemoryThrottlerForTest(t)
th, ok := throttler.(*memoryThrottler)
require.True(ok, "expected memoryThrottler, got %T", throttler)
th.getNow = timing.getNow
th.doDelay = timing.doDelay
hub.throttler = th
throttler, err := async.NewCustomMemoryThrottler(timing.getNow, timing.doDelay)
require.NoError(err)
t.Cleanup(func() {
throttler.Close()
})
hub.throttler = throttler
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()